Spaces · Build error

yuyingge committed · Commit 590af54 · Parent(s): 1264376

Add application file

This view is limited to 50 files because it contains too many changes. See the raw diff for the full change set.
- License.txt +335 -0
- configs/.DS_Store +0 -0
- configs/clm_models/.DS_Store +0 -0
- configs/clm_models/agent_seed_x_i.yaml +23 -0
- configs/clm_models/llm_seed_x_i.yaml +3 -0
- configs/discrete_model/.DS_Store +0 -0
- configs/discrete_model/discrete_identity.yaml +1 -0
- configs/processer/.DS_Store +0 -0
- configs/processer/qwen_448_transform.yaml +4 -0
- configs/sdxl_adapter/.DS_Store +0 -0
- configs/sdxl_adapter/sdxl_qwen_vit_resampler_l4_q64_full_with_latent_image_pretrain_no_normalize.yaml +20 -0
- configs/sdxl_adapter/sdxl_qwen_vit_resampler_l4_q64_pretrain_no_normalize.yaml +18 -0
- configs/tokenizer/.DS_Store +0 -0
- configs/tokenizer/clm_llama_tokenizer_224loc_anyres.yaml +2 -0
- configs/visual_encoder/.DS_Store +0 -0
- configs/visual_encoder/qwen_vitg_448.yaml +11 -0
- pretrained/QwenViT/qwen_vit_G.pt +3 -0
- requirements.txt +11 -0
- seed_x/arrow.jpg +0 -0
- seed_x/bank.png +0 -0
- src/.DS_Store +0 -0
- src/demo/__pycache__/conversation.cpython-311.pyc +0 -0
- src/demo/__pycache__/conversation.cpython-38.pyc +0 -0
- src/demo/__pycache__/utils.cpython-311.pyc +0 -0
- src/demo/__pycache__/utils.cpython-38.pyc +0 -0
- src/demo/configs/agent_13b_anyres_out_64_pretrain_merged.yaml +29 -0
- src/demo/configs/agent_13b_in100_out64_rs5_merged_pretrain.yaml +22 -0
- src/demo/configs/llama2chat13b_merged_100imgtokens.yaml +12 -0
- src/demo/conversation.py +182 -0
- src/demo/seed_llama_flask.py +379 -0
- src/demo/seed_llama_gradio.py +465 -0
- src/demo/utils.py +83 -0
- src/inference/.DS_Store +0 -0
- src/inference/__pycache__/any_res.cpython-311.pyc +0 -0
- src/inference/__pycache__/any_res.cpython-38.pyc +0 -0
- src/inference/any_res.py +257 -0
- src/inference/eval_img2edit_seed_x.py +155 -0
- src/inference/eval_img2text_seed_x.py +235 -0
- src/inference/eval_text2img_seed_x.py +94 -0
- src/models/detokenizer/__init__.py +1 -0
- src/models/detokenizer/__pycache__/__init__.cpython-311.pyc +0 -0
- src/models/detokenizer/__pycache__/__init__.cpython-38.pyc +0 -0
- src/models/detokenizer/__pycache__/adapter_modules.cpython-311.pyc +0 -0
- src/models/detokenizer/__pycache__/adapter_modules.cpython-38.pyc +0 -0
- src/models/detokenizer/__pycache__/attention_processor.cpython-38.pyc +0 -0
- src/models/detokenizer/__pycache__/ipa_utils.cpython-38.pyc +0 -0
- src/models/detokenizer/__pycache__/pipeline_stable_diffusion_t2i_edit.cpython-38.pyc +0 -0
- src/models/detokenizer/__pycache__/pipeline_stable_diffusion_xl_t2i_edit.cpython-311.pyc +0 -0
- src/models/detokenizer/__pycache__/pipeline_stable_diffusion_xl_t2i_edit.cpython-38.pyc +0 -0
- src/models/detokenizer/__pycache__/resampler.cpython-311.pyc +0 -0
License.txt
ADDED
@@ -0,0 +1,335 @@
Tencent is pleased to support the open source community by making Seed-X available.

Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.

Seed-X is licensed under the Apache License Version 2.0 except for the third-party components listed below.


Terms of the Apache License Version 2.0:
--------------------------------------------------------------------
Apache License

Version 2.0, January 2004

http://www.apache.org/licenses/

TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.

"License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document.

"Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License.

"Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity.

"You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License.

"Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files.

"Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types.

"Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below).

"Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof.

"Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution."

"Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work.

2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form.

3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed.

4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions:

You must give any other recipients of the Work or Derivative Works a copy of this License; and

You must cause any modified files to carry prominent notices stating that You changed the files; and

You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and

If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License.

You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License.

5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions.

6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file.

7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License.

8. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages.

9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability.

END OF TERMS AND CONDITIONS


Other dependencies and licenses:


Open Source Software Licensed under the Apache License Version 2.0:
--------------------------------------------------------------------
1. transformers
Copyright 2018- The Hugging Face team. All rights reserved.
Source code of this software can be obtained from: https://github.com/huggingface/transformers/blob/v4.30.2/

2. diffusers
Copyright 2023 The HuggingFace Team. All rights reserved.
Source code of this software can be obtained from: https://github.com/huggingface/diffusers/blob/v0.25.0/

A copy of Apache 2.0 has been included in this file.


Open Source Software Licensed under the BSD 3-Clause License:
--------------------------------------------------------------------
1. torchvision
Copyright (c) Soumith Chintala 2016,
All rights reserved.

Terms of the BSD 3-Clause License:
--------------------------------------------------------------------
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:

1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.

2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.

3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.


Open Source Software Licensed under the BSD 3-Clause License and Other Licenses of the Third-Party Components therein:
--------------------------------------------------------------------
1. numpy
Copyright (c) 2005-2021, NumPy Developers.
All rights reserved.

A copy of the BSD 3-Clause License is included in this file.

For the license of other third party components, please refer to the following URL:
https://github.com/numpy/numpy/blob/v1.20.1/LICENSES_bundled.txt


Open Source Software Licensed under the BSD 3-Clause License and Other Licenses of the Third-Party Components therein:
--------------------------------------------------------------------
1. torch
Copyright (c) 2016- Facebook, Inc (Adam Paszke)
Copyright (c) 2014- Facebook, Inc (Soumith Chintala)
Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert)
Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu)
Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu)
Copyright (c) 2011-2013 NYU (Clement Farabet)
Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston)
Copyright (c) 2006 Idiap Research Institute (Samy Bengio)
Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz)

A copy of the BSD 3-Clause License is included in this file.

For the license of other third party components, please refer to the following URL:
https://github.com/pytorch/pytorch/blob/v2.0.1/NOTICE


Open Source Software Licensed under the LLAMA 2 Community License:
--------------------------------------------------------------------
1. Llama 2
Copyright (c) Meta Platforms, Inc. All Rights Reserved.


Terms of the LLAMA 2 COMMUNITY LICENSE AGREEMENT:
--------------------------------------------------------------------
LLAMA 2 COMMUNITY LICENSE AGREEMENT
Llama 2 Version Release Date: July 18, 2023

"Agreement" means the terms and conditions for use, reproduction, distribution and
modification of the Llama Materials set forth herein.

"Documentation" means the specifications, manuals and documentation
accompanying Llama 2 distributed by Meta at ai.meta.com/resources/models-and-
libraries/llama-downloads/.

"Licensee" or "you" means you, or your employer or any other person or entity (if
you are entering into this Agreement on such person or entity's behalf), of the age
required under applicable laws, rules or regulations to provide legal consent and that
has legal authority to bind your employer or such other person or entity if you are
entering in this Agreement on their behalf.

"Llama 2" means the foundational large language models and software and
algorithms, including machine-learning model code, trained model weights,
inference-enabling code, training-enabling code, fine-tuning enabling code and other
elements of the foregoing distributed by Meta at ai.meta.com/resources/models-and-
libraries/llama-downloads/.

"Llama Materials" means, collectively, Meta's proprietary Llama 2 and
Documentation (and any portion thereof) made available under this Agreement.

"Meta" or "we" means Meta Platforms Ireland Limited (if you are located in or, if you
are an entity, your principal place of business is in the EEA or Switzerland) and Meta
Platforms, Inc. (if you are located outside of the EEA or Switzerland).

By clicking "I Accept" below or by using or distributing any portion or element of the
Llama Materials, you agree to be bound by this Agreement.

1. License Rights and Redistribution.

a. Grant of Rights. You are granted a non-exclusive, worldwide, non-
transferable and royalty-free limited license under Meta's intellectual property or
other rights owned by Meta embodied in the Llama Materials to use, reproduce,
distribute, copy, create derivative works of, and make modifications to the Llama
Materials.

b. Redistribution and Use.

i. If you distribute or make the Llama Materials, or any derivative works
thereof, available to a third party, you shall provide a copy of this Agreement to such
third party.
ii. If you receive Llama Materials, or any derivative works thereof, from
a Licensee as part of an integrated end user product, then Section 2 of this
Agreement will not apply to you.

iii. You must retain in all copies of the Llama Materials that you
distribute the following attribution notice within a "Notice" text file distributed as a
part of such copies: "Llama 2 is licensed under the LLAMA 2 Community License,
Copyright (c) Meta Platforms, Inc. All Rights Reserved."

iv. Your use of the Llama Materials must comply with applicable laws
and regulations (including trade compliance laws and regulations) and adhere to the
Acceptable Use Policy for the Llama Materials (available at
https://ai.meta.com/llama/use-policy), which is hereby incorporated by reference into
this Agreement.

v. You will not use the Llama Materials or any output or results of the
Llama Materials to improve any other large language model (excluding Llama 2 or
derivative works thereof).

2. Additional Commercial Terms. If, on the Llama 2 version release date, the
monthly active users of the products or services made available by or for Licensee,
or Licensee's affiliates, is greater than 700 million monthly active users in the
preceding calendar month, you must request a license from Meta, which Meta may
grant to you in its sole discretion, and you are not authorized to exercise any of the
rights under this Agreement unless or until Meta otherwise expressly grants you
such rights.

3. Disclaimer of Warranty. UNLESS REQUIRED BY APPLICABLE LAW, THE
LLAMA MATERIALS AND ANY OUTPUT AND RESULTS THEREFROM ARE
PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
EITHER EXPRESS OR IMPLIED, INCLUDING, WITHOUT LIMITATION, ANY
WARRANTIES OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY, OR
FITNESS FOR A PARTICULAR PURPOSE. YOU ARE SOLELY RESPONSIBLE
FOR DETERMINING THE APPROPRIATENESS OF USING OR REDISTRIBUTING
THE LLAMA MATERIALS AND ASSUME ANY RISKS ASSOCIATED WITH YOUR
USE OF THE LLAMA MATERIALS AND ANY OUTPUT AND RESULTS.

4. Limitation of Liability. IN NO EVENT WILL META OR ITS AFFILIATES BE
LIABLE UNDER ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, TORT,
NEGLIGENCE, PRODUCTS LIABILITY, OR OTHERWISE, ARISING OUT OF THIS
AGREEMENT, FOR ANY LOST PROFITS OR ANY INDIRECT, SPECIAL,
CONSEQUENTIAL, INCIDENTAL, EXEMPLARY OR PUNITIVE DAMAGES, EVEN
IF META OR ITS AFFILIATES HAVE BEEN ADVISED OF THE POSSIBILITY OF
ANY OF THE FOREGOING.

5. Intellectual Property.

a. No trademark licenses are granted under this Agreement, and in
connection with the Llama Materials, neither Meta nor Licensee may use any name
or mark owned by or associated with the other or any of its affiliates, except as
required for reasonable and customary use in describing and redistributing the
Llama Materials.

b. Subject to Meta's ownership of Llama Materials and derivatives made by or
for Meta, with respect to any derivative works and modifications of the Llama
Materials that are made by you, as between you and Meta, you are and will be the
owner of such derivative works and modifications.

c. If you institute litigation or other proceedings against Meta or any entity
(including a cross-claim or counterclaim in a lawsuit) alleging that the Llama
Materials or Llama 2 outputs or results, or any portion of any of the foregoing,
constitutes infringement of intellectual property or other rights owned or licensable
by you, then any licenses granted to you under this Agreement shall terminate as of
the date such litigation or claim is filed or instituted. You will indemnify and hold
harmless Meta from and against any claim by any third party arising out of or related
to your use or distribution of the Llama Materials.

6. Term and Termination. The term of this Agreement will commence upon your
acceptance of this Agreement or access to the Llama Materials and will continue in
full force and effect until terminated in accordance with the terms and conditions
herein. Meta may terminate this Agreement if you are in breach of any term or
condition of this Agreement. Upon termination of this Agreement, you shall delete
and cease use of the Llama Materials. Sections 3, 4 and 7 shall survive the
termination of this Agreement.

7. Governing Law and Jurisdiction. This Agreement will be governed and
construed under the laws of the State of California without regard to choice of law
principles, and the UN Convention on Contracts for the International Sale of Goods
does not apply to this Agreement. The courts of California shall have exclusive
jurisdiction of any dispute arising out of this Agreement.


Open Source Software Licensed under the Tongyi Qianwen LICENSE AGREEMENT:
--------------------------------------------------------------------
1. Qwen-VL
Copyright (c) Alibaba Cloud. All Rights Reserved.


Terms of the Tongyi Qianwen LICENSE AGREEMENT:
--------------------------------------------------------------------
Tongyi Qianwen LICENSE AGREEMENT

Tongyi Qianwen Release Date: August 23, 2023

By clicking to agree or by using or distributing any portion or element of the Tongyi Qianwen Materials, you will be deemed to have recognized and accepted the content of this Agreement, which is effective immediately.

1. Definitions
a. This Tongyi Qianwen LICENSE AGREEMENT (this "Agreement") shall mean the terms and conditions for use, reproduction, distribution and modification of the Materials as defined by this Agreement.
b. "We"(or "Us") shall mean Alibaba Cloud.
c. "You" (or "Your") shall mean a natural person or legal entity exercising the rights granted by this Agreement and/or using the Materials for any purpose and in any field of use.
d. "Third Parties" shall mean individuals or legal entities that are not under common control with Us or You.
e. "Tongyi Qianwen" shall mean the large language models (including Qwen-VL model and Qwen-VL-Chat model), and software and algorithms, consisting of trained model weights, parameters (including optimizer states), machine-learning model code, inference-enabling code, training-enabling code, fine-tuning enabling code and other elements of the foregoing distributed by Us.
f. "Materials" shall mean, collectively, Alibaba Cloud's proprietary Tongyi Qianwen and Documentation (and any portion thereof) made available under this Agreement.
g. "Source" form shall mean the preferred form for making modifications, including but not limited to model source code, documentation source, and configuration files.
h. "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation,
and conversions to other media types.

2. Grant of Rights
You are granted a non-exclusive, worldwide, non-transferable and royalty-free limited license under Alibaba Cloud's intellectual property or other rights owned by Us embodied in the Materials to use, reproduce, distribute, copy, create derivative works of, and make modifications to the Materials.

3. Redistribution
You may reproduce and distribute copies of the Materials or derivative works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions:
a. You shall give any other recipients of the Materials or derivative works a copy of this Agreement;
b. You shall cause any modified files to carry prominent notices stating that You changed the files;
c. You shall retain in all copies of the Materials that You distribute the following attribution notices within a "Notice" text file distributed as a part of such copies: "Tongyi Qianwen is licensed under the Tongyi Qianwen LICENSE AGREEMENT, Copyright (c) Alibaba Cloud. All Rights Reserved."; and
d. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such derivative works as a whole, provided Your use, reproduction, and distribution of the work otherwise complies with the terms and conditions of this Agreement.

4. Restrictions
If you are commercially using the Materials, and your product or service has more than 100 million monthly active users, You shall request a license from Us. You cannot exercise your rights under this Agreement without our express authorization.

5. Rules of use
a. The Materials may be subject to export controls or restrictions in China, the United States or other countries or regions. You shall comply with applicable laws and regulations in your use of the Materials.
b. You can not use the Materials or any output therefrom to improve any other large language model (excluding Tongyi Qianwen or derivative works thereof).

6. Intellectual Property
a. We retain ownership of all intellectual property rights in and to the Materials and derivatives made by or for Us. Conditioned upon compliance with the terms and conditions of this Agreement, with respect to any derivative works and modifications of the Materials that are made by you, you are and will be the owner of such derivative works and modifications.
b. No trademark license is granted to use the trade names, trademarks, service marks, or product names of Us, except as required to fulfill notice requirements under this Agreement or as required for reasonable and customary use in describing and redistributing the Materials.
c. If you commence a lawsuit or other proceedings (including a cross-claim or counterclaim in a lawsuit) against Us or any entity alleging that the Materials or any output therefrom, or any part of the foregoing, infringe any intellectual property or other right owned or licensable by you, then all licences granted to you under this Agreement shall terminate as of the date such lawsuit or other proceeding is commenced or brought.

7. Disclaimer of Warranty and Limitation of Liability

a. We are not obligated to support, update, provide training for, or develop any further version of the Tongyi Qianwen Materials or to grant any license thereto.
b. THE MATERIALS ARE PROVIDED "AS IS" WITHOUT ANY EXPRESS OR IMPLIED WARRANTY OF ANY KIND INCLUDING WARRANTIES OF MERCHANTABILITY, NONINFRINGEMENT, OR FITNESS FOR A PARTICULAR PURPOSE. WE MAKE NO WARRANTY AND ASSUME NO RESPONSIBILITY FOR THE SAFETY OR STABILITY OF THE MATERIALS AND ANY OUTPUT THEREFROM.
c. IN NO EVENT SHALL WE BE LIABLE TO YOU FOR ANY DAMAGES, INCLUDING, BUT NOT LIMITED TO ANY DIRECT, OR INDIRECT, SPECIAL OR CONSEQUENTIAL DAMAGES ARISING FROM YOUR USE OR INABILITY TO USE THE MATERIALS OR ANY OUTPUT OF IT, NO MATTER HOW IT’S CAUSED.
d. You will defend, indemnify and hold harmless Us from and against any claim by any third party arising out of or related to your use or distribution of the Materials.

8. Survival and Termination.
a. The term of this Agreement shall commence upon your acceptance of this Agreement or access to the Materials and will continue in full force and effect until terminated in accordance with the terms and conditions herein.
b. We may terminate this Agreement if you breach any of the terms or conditions of this Agreement. Upon termination of this Agreement, you must delete and cease use of the Materials. Sections 7 and 9 shall survive the termination of this Agreement.

9. Governing Law and Jurisdiction.
a. This Agreement and any dispute arising out of or relating to it will be governed by the laws of China, without regard to conflict of law principles, and the UN Convention on Contracts for the International Sale of Goods does not apply to this Agreement.
b. The People's Courts in Hangzhou City shall have exclusive jurisdiction over any dispute arising out of this Agreement.
configs/.DS_Store
ADDED
Binary file (8.2 kB).
configs/clm_models/.DS_Store
ADDED
Binary file (6.15 kB).
configs/clm_models/agent_seed_x_i.yaml
ADDED
@@ -0,0 +1,23 @@
_target_: src.models.mllm.seed_x.ContinuousLVLM.from_pretrained
input_resampler:
  _target_: src.models.tokenizer.qwen_visual.Resampler
  grid_size: 8
  embed_dim: 5120
  num_heads: 32
  kv_dim: 4096

output_resampler:
  _target_: src.models.tokenizer.qwen_visual.Resampler
  grid_size: 8
  embed_dim: 4096
  num_heads: 32
  kv_dim: 5120

add_patch_pos: True
vit_down: True
mse: True

lm_loss_scale: 1.0
rec_loss_scale: 6.0

pretrained_model_path: https://huggingface.co/AILab-CVC/SEED-X-17B/blob/main/seed_x_i/agent/pytorch_model.bin
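Note: the following is a minimal sketch (not part of the commit) of how a config like this is consumed; it mirrors the loading pattern in src/demo/seed_llama_flask.py further down in this diff and assumes the companion LLM config configs/clm_models/llm_seed_x_i.yaml with paths relative to the repository root.

import hydra
import torch
from omegaconf import OmegaConf

# Build the language model first, then the agent; Hydra instantiates the nested
# input/output Resamplers before calling ContinuousLVLM.from_pretrained with them
# plus the extra keyword arguments (here the llm).
llm = hydra.utils.instantiate(OmegaConf.load('configs/clm_models/llm_seed_x_i.yaml'), torch_dtype=torch.float16)
agent = hydra.utils.instantiate(OmegaConf.load('configs/clm_models/agent_seed_x_i.yaml'), llm=llm)
agent.eval().to('cuda:0', dtype=torch.float16)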
configs/clm_models/llm_seed_x_i.yaml
ADDED
@@ -0,0 +1,3 @@
_target_: src.models.mllm.modeling_llama_xformer.LlamaForCausalLM.from_pretrained
pretrained_model_name_or_path: https://huggingface.co/AILab-CVC/SEED-X-17B/tree/main/seed_x_i/llm
low_cpu_mem_usage: True
configs/discrete_model/.DS_Store
ADDED
Binary file (6.15 kB).
configs/discrete_model/discrete_identity.yaml
ADDED
@@ -0,0 +1 @@
_target_: src.models.tokenizer.discrete_models.DiscreteModleIdentity
configs/processer/.DS_Store
ADDED
Binary file (6.15 kB).
configs/processer/qwen_448_transform.yaml
ADDED
@@ -0,0 +1,4 @@
_target_: src.processer.transforms.get_transform
type: clip
image_size: 448
keep_ratio: False
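Note: a minimal usage sketch (not part of the commit), assuming get_transform returns a CLIP-style preprocessing callable that maps a PIL image to a 3x448x448 tensor (keep_ratio: False implies no aspect-preserving resize); the image path is one of the demo assets added in this commit.

import hydra
from omegaconf import OmegaConf
from PIL import Image

# Instantiate the transform from this config and apply it to a demo image.
transform = hydra.utils.instantiate(OmegaConf.load('configs/processer/qwen_448_transform.yaml'))
pixel_values = transform(Image.open('seed_x/bank.png').convert('RGB'))
print(pixel_values.shape)  # expected torch.Size([3, 448, 448]) under the assumption above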
configs/sdxl_adapter/.DS_Store
ADDED
Binary file (6.15 kB).
configs/sdxl_adapter/sdxl_qwen_vit_resampler_l4_q64_full_with_latent_image_pretrain_no_normalize.yaml
ADDED
@@ -0,0 +1,20 @@
_target_: src.models.detokenizer.adapter_modules.SDXLAdapterWithLatentImage.from_pretrained

resampler:
  _target_: src.models.detokenizer.resampler.ResamplerXLV2
  dim: 1024
  depth: 4
  dim_head: 64
  heads: 16
  num_queries: 64
  embedding_dim: 4096
  output1_dim: 768
  output2_dim: 1280
  ff_mult: 4
  normalize: False

full_ft: True
set_trainable_late: False

vit_down: True
pretrained_model_path: pretrained/seed_detokenizer/second_stage/pytorch_model.bin
configs/sdxl_adapter/sdxl_qwen_vit_resampler_l4_q64_pretrain_no_normalize.yaml
ADDED
@@ -0,0 +1,18 @@
_target_: src.models.detokenizer.adapter_modules.SDXLAdapter.from_pretrained

resampler:
  _target_: src.models.detokenizer.resampler.ResamplerXLV2
  dim: 1024
  depth: 4
  dim_head: 64
  heads: 16
  num_queries: 64
  embedding_dim: 4096
  output1_dim: 768
  output2_dim: 1280
  ff_mult: 4
  normalize: False

vit_down: True

pretrained_model_path: https://huggingface.co/AILab-CVC/SEED-X-17B/blob/main/seed_detokenizer/first_stage/pytorch_model.bin
configs/tokenizer/.DS_Store
ADDED
Binary file (6.15 kB).
configs/tokenizer/clm_llama_tokenizer_224loc_anyres.yaml
ADDED
@@ -0,0 +1,2 @@
_target_: transformers.LlamaTokenizer.from_pretrained
pretrained_model_name_or_path: https://huggingface.co/AILab-CVC/SEED-X-17B/tree/main/cvlm_llama2_tokenizer_100img_and_224loc_addpatch
configs/visual_encoder/.DS_Store
ADDED
Binary file (6.15 kB).
configs/visual_encoder/qwen_vitg_448.yaml
ADDED
@@ -0,0 +1,11 @@
_target_: src.models.tokenizer.qwen_visual.VisionTransformerWithAttnPool.from_pretrained
heads: 16
image_size: 448
image_start_id: 151857
layers: 48
mlp_ratio: 4.9231
output_dim: 4096
patch_size: 14
width: 1664

pretrained_model_path: pretrained/QwenViT/qwen_vit_G.pt
pretrained/QwenViT/qwen_vit_G.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:d951083fc79b07bdb84be61944eb263b8e14572fe2dc4fa80b0447f83064463c
size 3871440281
requirements.txt
ADDED
@@ -0,0 +1,11 @@
torch==2.0.1
hydra-core
transformers==4.30.2
diffusers==0.25.0
sentencepiece
opencv-python
deepspeed
pyrootutils
xformers>=0.0.20
accelerate
transformers_stream_generator
seed_x/arrow.jpg
ADDED
seed_x/bank.png
ADDED
src/.DS_Store
ADDED
Binary file (10.2 kB).
src/demo/__pycache__/conversation.cpython-311.pyc
ADDED
Binary file (8.21 kB).
src/demo/__pycache__/conversation.cpython-38.pyc
ADDED
Binary file (4.43 kB).
src/demo/__pycache__/utils.cpython-311.pyc
ADDED
Binary file (4.39 kB).
src/demo/__pycache__/utils.cpython-38.pyc
ADDED
Binary file (2.32 kB).
src/demo/configs/agent_13b_anyres_out_64_pretrain_merged.yaml
ADDED
@@ -0,0 +1,29 @@
_target_: src.models_clm.models.ContinuousLVLM.from_pretrained
input_resampler:
  _target_: src.models.qwen_visual.Resampler
  grid_size: 8
  embed_dim: 5120
  num_heads: 32
  kv_dim: 4096

output_resampler:
  _target_: src.models.qwen_visual.Resampler
  grid_size: 8
  embed_dim: 4096
  num_heads: 32
  kv_dim: 5120

add_patch_pos: True
vit_down: True
mse: True

lm_loss_scale: 1.0
rec_loss_scale: 6.0

#pretrained_model_path: /chat_sh/share_300719895/user/jinguozhu/codes/work_dirs/sft_exp_new_acc4/checkpoint-2000/pytorch_model.bin
#pretrained_model_path: /chat_sh/share_300719895/user/yuyingge/jinguo_code/DiscreteLearning_debug/train_output/03_27_any_res_sft_from_merged_10k/checkpoint-9000/pytorch_model.bin
#pretrained_model_path: /chat_sh/share_300719895/user/yuyingge/jinguo_code/DiscreteLearning_debug/train_output/03_27_any_res_sft_from_merged_10k/checkpoint-8000-merged/agent/pytorch_model.bin
#pretrained_model_path: /chat_sh/share_300719895/user/yuyingge/jinguo_code/DiscreteLearning_debug/train_output/04_09_any_res_sft_editing_from_merged_10k_32a100_new_data/checkpoint-6000-merged/agent/pytorch_model.bin
#pretrained_model_path: /chat_sh/share_300719895/user/yuyingge/jinguo_code/DiscreteLearning_debug/train_output/04_16_any_res_sft_editing_from_merged_H800_23k_16_gpu_2_new/checkpoint-6000-merged/agent/pytorch_model.bin
#pretrained_model_path: /chat_sh/share_300719895/user/yuyingge/jinguo_code/DiscreteLearning_debug/train_output/04_16_any_res_sft_com_gen_from_merged_H800_23k/checkpoint-15000-merged/agent/pytorch_model.bin
pretrained_model_path: /group/40034/yuyingge/SEED_X_inference/pretrained/seed_x_i/agent/pytorch_model.bin
src/demo/configs/agent_13b_in100_out64_rs5_merged_pretrain.yaml
ADDED
@@ -0,0 +1,22 @@
_target_: src.models_clm.models.ContinuousLVLM.from_pretrained
input_resampler:
  _target_: src.models.qwen_visual.Resampler
  grid_size: 10
  embed_dim: 5120
  num_heads: 32
  kv_dim: 4096

output_resampler:
  _target_: src.models.qwen_visual.Resampler
  grid_size: 16
  embed_dim: 4096
  num_heads: 32
  kv_dim: 5120

lm_loss_scale: 1.0
rec_loss_scale: 5.0

# pretrained_model_path: /chat_sh/share_300719895/user/sijiezhao/Program/2023/DiscreteLearning/train_output_clm_sh/1208_llama2chat13b_lora_clm_qwen-vit-448_pretrain_rs5_64a100pro/checkpoint-27000/pytorch_model.bin
# pretrained_model_path: /apdcephfs_cq4/share_2942043/Multimodal/sijiezhao/DiscreteLearning/train_output/sft_from_1208_llama2chat13b_lora_clm_qwen-vit-448_pretrain_rs5_64a100pro_40k/ckpt-10000-merged/agent/pytorch_model.bin
# pretrained_model_path: /apdcephfs_cq4/share_2942043/Multimodal/sijiezhao/DiscreteLearning/train_output/sft_from_1208_llama2chat13b_lora_clm_qwen-vit-448_pretrain_rs5_64a100pro_40k/ckpt-5000-merged/agent/pytorch_model.bin
pretrained_model_path: /apdcephfs_cq4/share_2942043/Multimodal/sijiezhao/DiscreteLearning/train_output/sft_from_1211_llama2chat13b_lora_clm_qwen-vit-448_pretrain_rs5_64a100pro_grounding_27k/ckpt-4000-merged/agent/pytorch_model.bin
src/demo/configs/llama2chat13b_merged_100imgtokens.yaml
ADDED
@@ -0,0 +1,12 @@
_target_: src.models_clm.modeling_llama_xformer.LlamaForCausalLM.from_pretrained
# _target_: transformers.LlamaForCausalLM.from_pretrained
# pretrained_model_name_or_path: /apdcephfs_cq4/share_2942043/Multimodal/sijiezhao/DiscreteLearning/train_output/sft_from_1208_llama2chat13b_lora_clm_qwen-vit-448_pretrain_rs5_64a100pro_40k/ckpt-10000-merged/llm
# pretrained_model_name_or_path: /apdcephfs_cq4/share_2942043/Multimodal/sijiezhao/DiscreteLearning/train_output/sft_from_1208_llama2chat13b_lora_clm_qwen-vit-448_pretrain_rs5_64a100pro_40k/ckpt-5000-merged/llm
#pretrained_model_name_or_path: /chat_sh/share_300719895/user/jinguozhu/codes/work_dirs/pretraining_anyres_newexp_v2/checkpoint-10000-merged/llm
#pretrained_model_name_or_path: /chat_sh/share_300719895/user/yuyingge/jinguo_code/DiscreteLearning_debug/train_output/03_27_any_res_sft_from_merged_10k/checkpoint-8000-merged/llm
#pretrained_model_name_or_path: /chat_sh/share_300719895/user/yuyingge/jinguo_code/DiscreteLearning_debug/train_output/04_09_any_res_sft_editing_from_merged_10k_32a100_new_data/checkpoint-6000-merged/llm
#pretrained_model_name_or_path: /chat_sh/share_300719895/user/yuyingge/jinguo_code/DiscreteLearning_debug/train_output/04_16_any_res_sft_editing_from_merged_H800_23k_16_gpu_2_new/checkpoint-6000-merged/llm
#pretrained_model_name_or_path: /chat_sh/share_300719895/user/yuyingge/jinguo_code/DiscreteLearning_debug/train_output/04_16_any_res_sft_com_gen_from_merged_H800_23k/checkpoint-15000-merged/llm
pretrained_model_name_or_path: /group/40034/yuyingge/SEED_X_inference/pretrained/seed_x_i/llm
low_cpu_mem_usage: True
src/demo/conversation.py
ADDED
@@ -0,0 +1,182 @@
import dataclasses
from enum import auto, Enum
from typing import List, Tuple

import io
import base64
import os
from PIL import Image
import copy

IMG_FLAG = '<image>'


class SeparatorStyle(Enum):
    """Different separator style."""
    SINGLE = auto()
    TWO = auto()
    MPT = auto()
    PLAIN = auto()
    LLAMA_2 = auto()


def decode_image(encoded_image: str) -> Image:
    decoded_bytes = base64.b64decode(encoded_image.encode('utf-8'))
    buffer = io.BytesIO(decoded_bytes)
    image = Image.open(buffer)
    return image


def encode_image(image: Image.Image, format: str = 'PNG') -> str:
    with io.BytesIO() as buffer:
        image.save(buffer, format=format)
        encoded_image = base64.b64encode(buffer.getvalue()).decode('utf-8')
    return encoded_image


@dataclasses.dataclass
class Conversation:
    """A class that keeps all conversation history."""
    system: str
    roles: List[str]
    messages: List[dict]  # multi-turn -> user & assistant -> {'images': [PIL.Image,], 'text': str}
    offset: int
    sep_style: SeparatorStyle = SeparatorStyle.SINGLE
    sep: str = "###"
    sep2: str = None
    version: str = "Unknown"

    skip_next: bool = False

    def get_prompt(self):
        messages = copy.deepcopy(self.messages)
        if self.sep_style == SeparatorStyle.SINGLE:
            if self.system is None or self.system == '':
                text = ''
            else:
                text = self.system + self.sep
            images = []
            for message in messages:
                text += message['role'] + ": " + message['message']['text'] + self.sep
                for image_path in message['message']['images']:
                    image = Image.open(image_path).resize((256, 256))
                    image_base64 = encode_image(image)
                    images.append(image_base64)

            text += self.roles[1] + ":"
        elif self.sep_style == SeparatorStyle.LLAMA_2:
            b_token = "[INST] "
            e_token = " [/INST]"
            if self.system is None or self.system == '':
                text = ''
            else:
                text = f"<<SYS>>\n{self.system}\n<</SYS>>\n\n"
            images = []
            for idx, message in enumerate(messages):
                # text += message['role'] + ": " + message['message']['text'] + self.sep
                if idx % 2 == 0:
                    text += b_token + message['message']['text'] + e_token + self.sep
                else:
                    text += message['message']['text'] + self.sep

                for image_path in message['message']['images']:
                    image = Image.open(image_path)
                    image_base64 = encode_image(image)
                    images.append(image_base64)
        else:
            raise NotImplementedError

        return {'text': text, 'images': images}

    # def update_image_ids(self, images_ids):
    #     image_count = 0
    #     for message in self.messages:
    #         for idx in range(len(message['message']['images_ids'])):
    #             if message['message']["images_ids"][idx] is None:
    #                 message['message']["images_ids"][idx] = images_ids[image_count]
    #                 image_count += 1

    #     assert len(images_ids) == image_count, print(len(images_ids), image_count)

    def append_message(self, role, message):
        self.messages.append([role, message])

    def to_gradio_chatbot(self):
        dialog = []
        for i, single_turn in enumerate(self.messages[self.offset:]):
            single_turn = single_turn['message']
            text_list = single_turn['text'].split(IMG_FLAG)
            assert len(text_list) == len(single_turn['images']) + 1, print(text_list, len(single_turn['images']))
            message = ''
            for image_idx in range(len(single_turn['images'])):
                # image = single_turn['images'][image_idx]
                # image_base64 = encode_image(image)
                # image_str = f'<img src="data:image/png;base64,{image_base64}" alt="user upload image" />'
                image_path = single_turn['images'][image_idx]
                if image_path == '':
                    message += text_list[image_idx] + '<corrupt_image>'
                else:
                    message += text_list[image_idx] + f'![](file={image_path})'
            message += text_list[-1]

            if i % 2 == 0:
                dialog.append([message, None])
            else:
                dialog[-1][-1] = message

        return dialog

    def copy(self):
        return Conversation(system=self.system,
                            roles=self.roles,
                            messages=copy.deepcopy(self.messages),
                            offset=self.offset,
                            sep_style=self.sep_style,
                            sep=self.sep,
                            sep2=self.sep2,
                            version=self.version)

    def dict(self):
        messages = copy.deepcopy(self.messages)
        for message in messages:
            for i in range(len(message['message']['images'])):
                message['message']['images'][i] = os.path.basename(message['message']['images'][i])
        return {
            "system": self.system,
            "roles": self.roles,
            "messages": messages,
            "offset": self.offset,
            "sep": self.sep,
            "sep2": self.sep2,
        }


conv_seed_vicuna = Conversation(
    system="",
    roles=("USER", "ASSISTANT"),
    version="v2",
    messages=[],
    offset=0,
    sep_style=SeparatorStyle.SINGLE,
    sep='\n',
)

conv_seed_vicuna_system = Conversation(
    system="A chat between a curious user and an artificial intelligence assistant. ",
    roles=("USER", "ASSISTANT"),
    version="v2",
    messages=[],
    offset=0,
    sep_style=SeparatorStyle.SINGLE,
    sep='\n',
)

conv_seed_llama2 = Conversation(
    system="",
    roles=("[INST]", "[/INST]"),
    version="v2",
    messages=[],
    offset=0,
    sep_style=SeparatorStyle.LLAMA_2,
    sep='\n',
)
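Note: a minimal usage sketch (not part of the commit) for the Conversation class above. get_prompt() reads each entry as {'role': ..., 'message': {'text': ..., 'images': [path, ...]}} (append_message stores a [role, message] list instead, so the dict is built directly here); the image path is one of the demo assets added in this commit.

from src.demo.conversation import conv_seed_vicuna, IMG_FLAG

conv = conv_seed_vicuna.copy()
conv.messages.append({
    'role': 'USER',
    'message': {'text': f'{IMG_FLAG} What is in this image?', 'images': ['seed_x/bank.png']},
})
prompt = conv.get_prompt()
print(prompt['text'])         # "USER: <image> What is in this image?\nASSISTANT:"
print(len(prompt['images']))  # 1 base64-encoded image, resized to 256x256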
src/demo/seed_llama_flask.py
ADDED
@@ -0,0 +1,379 @@
1 |
+
import hydra
|
2 |
+
import pyrootutils
|
3 |
+
import torch
|
4 |
+
import re
|
5 |
+
import time
|
6 |
+
from omegaconf import OmegaConf
|
7 |
+
from flask import Flask, request
|
8 |
+
from typing import Optional
|
9 |
+
import transformers
|
10 |
+
from dataclasses import dataclass, field
|
11 |
+
import io
|
12 |
+
import base64
|
13 |
+
from PIL import Image
|
14 |
+
import numpy as np
|
15 |
+
import cv2
|
16 |
+
from diffusers import AutoencoderKL, UNet2DConditionModel, EulerDiscreteScheduler
|
17 |
+
|
18 |
+
|
19 |
+
pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
|
20 |
+
|
21 |
+
from src.data.any_res import process_anyres_image
|
22 |
+
|
23 |
+
BOI_TOKEN = '<img>'
|
24 |
+
BOP_TOKEN = '<patch>'
|
25 |
+
EOI_TOKEN = '</img>'
|
26 |
+
EOP_TOKEN = '</patch>'
|
27 |
+
IMG_TOKEN = '<img_{:05d}>'
|
28 |
+
|
29 |
+
IMG_FLAG = '<image>'
|
30 |
+
num_img_in_tokens = 64
|
31 |
+
num_img_out_tokens = 64
|
32 |
+
|
33 |
+
resolution_grids = ['1x1', '1x2', '1x3', '1x4', '1x5', '1x6', '1x10', '2x1', '3x1', '4x1', '5x1', '6x1', '10x1', '2x2', '2x3', '3x2', '2x4', '4x2']
|
34 |
+
base_resolution = 448
|
35 |
+
|
36 |
+
app = Flask(__name__)
|
37 |
+
|
38 |
+
|
39 |
+
def decode_image(encoded_image: str) -> Image:
|
40 |
+
decoded_bytes = base64.b64decode(encoded_image.encode('utf-8'))
|
41 |
+
buffer = io.BytesIO(decoded_bytes)
|
42 |
+
image = Image.open(buffer)
|
43 |
+
return image
|
44 |
+
|
45 |
+
|
46 |
+
def encode_image(image: Image.Image, format: str = 'PNG') -> str:
|
47 |
+
with io.BytesIO() as buffer:
|
48 |
+
image.save(buffer, format=format)
|
49 |
+
encoded_image = base64.b64encode(buffer.getvalue()).decode('utf-8')
|
50 |
+
return encoded_image
|
51 |
+
|
52 |
+
|
53 |
+
@dataclass
|
54 |
+
class Arguments:
|
55 |
+
image_transform: Optional[str] = field(default=None, metadata={"help": "config path of image transform"})
|
56 |
+
tokenizer: Optional[str] = field(default=None, metadata={"help": "config path of tokenizer used to initialize tokenizer"})
|
57 |
+
llm: Optional[str] = field(default=None, metadata={"help": "config path of llm"})
|
58 |
+
visual_encoder: Optional[str] = field(default=None, metadata={"help": "config path of visual encoder"})
|
59 |
+
sd_adapter: Optional[str] = field(default=None, metadata={"help": "config path of sd adapter"})
|
60 |
+
agent: Optional[str] = field(default=None, metadata={"help": "config path of agent model"})
|
61 |
+
diffusion_path: Optional[str] = field(default=None, metadata={"help": "diffusion model path"})
|
62 |
+
has_bbox: Optional[bool] = field(default=False, metadata={"help": "visualize the box"})
|
63 |
+
|
64 |
+
port: Optional[int] = field(default=80, metadata={"help": "network port"})
|
65 |
+
llm_device: Optional[str] = field(default='cuda:0', metadata={"help": "llm device"})
|
66 |
+
vit_sd_device: Optional[str] = field(default='cuda:0', metadata={"help": "sd and vit device"})
|
67 |
+
dtype: Optional[str] = field(default='fp16', metadata={"help": "mixed precision"})
|
68 |
+
|
69 |
+
multi_resolution: Optional[bool] = field(default=False, metadata={"help": "multi resolution"})
|
70 |
+
|
71 |
+
|
72 |
+
parser = transformers.HfArgumentParser(Arguments)
|
73 |
+
args, = parser.parse_args_into_dataclasses()
|
74 |
+
|
75 |
+
def extract_box(output_str):
|
76 |
+
boxes = re.findall('(.*?)<box_end>', output_str)
|
77 |
+
if len(boxes) >0:
|
78 |
+
bboxes = [[int(num) for num in re.findall(r'<loc-(\d+)>', box)] for box in boxes]
|
79 |
+
else:
|
80 |
+
bboxes = None
|
81 |
+
|
82 |
+
return bboxes
|
83 |
+
|
84 |
+
|
85 |
+
def visualize_bbox(image, bboxes):
|
86 |
+
img_width, img_height = image.size
|
87 |
+
image = np.array(image)
|
88 |
+
image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
|
89 |
+
for bbox in bboxes:
|
90 |
+
x_center, y_center, box_width, box_height = bbox
|
91 |
+
|
92 |
+
x_center = x_center / 224 * img_width
|
93 |
+
y_center = y_center / 224 * img_height
|
94 |
+
|
95 |
+
box_width = box_width /224 * img_width
|
96 |
+
box_height = box_height / 224 * img_height
|
97 |
+
|
98 |
+
x1 = int(x_center - box_width / 2)
|
99 |
+
y1 = int(y_center - box_height / 2)
|
100 |
+
x2 = int(x_center + box_width / 2)
|
101 |
+
y2 = int(y_center + box_height / 2)
|
102 |
+
|
103 |
+
cv2.rectangle(image, (x1, y1), (x2, y2), (0, 255, 0), 4)
|
104 |
+
|
105 |
+
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
106 |
+
image = Image.fromarray(image)
|
107 |
+
|
108 |
+
|
109 |
+
return image
|
110 |
+
|
111 |
+
|
112 |
+
|
113 |
+
|
114 |
+
class LLMService:
|
115 |
+
|
116 |
+
def __init__(self, args) -> None:
|
117 |
+
|
118 |
+
self.llm_device = args.llm_device
|
119 |
+
self.vit_sd_device = args.vit_sd_device
|
120 |
+
|
121 |
+
dtype = args.dtype
|
122 |
+
if dtype == 'fp16':
|
123 |
+
self.dtype = torch.float16
|
124 |
+
elif dtype == 'bf16':
|
125 |
+
self.dtype = torch.bfloat16
|
126 |
+
else:
|
127 |
+
raise ValueError
|
128 |
+
|
129 |
+
image_transform_cfg = OmegaConf.load(args.image_transform)
|
130 |
+
self.image_transform = hydra.utils.instantiate(image_transform_cfg)
|
131 |
+
|
132 |
+
tokenizer_cfg = OmegaConf.load(args.tokenizer)
|
133 |
+
self.tokenizer = hydra.utils.instantiate(tokenizer_cfg)
|
134 |
+
|
135 |
+
visual_encoder_cfg = OmegaConf.load(args.visual_encoder)
|
136 |
+
self.visual_encoder = hydra.utils.instantiate(visual_encoder_cfg)
|
137 |
+
self.visual_encoder.eval().to(self.vit_sd_device, dtype=self.dtype)
|
138 |
+
print('Init visual encoder done')
|
139 |
+
|
140 |
+
llm_cfg = OmegaConf.load(args.llm)
|
141 |
+
llm = hydra.utils.instantiate(llm_cfg, torch_dtype=self.dtype)
|
142 |
+
print('Init llm done.')
|
143 |
+
|
144 |
+
agent_cfg = OmegaConf.load(args.agent)
|
145 |
+
self.agent = hydra.utils.instantiate(agent_cfg, llm=llm)
|
146 |
+
|
147 |
+
self.agent.eval().to(self.llm_device, dtype=self.dtype)
|
148 |
+
print('Init agent model done')
|
149 |
+
|
150 |
+
noise_scheduler = EulerDiscreteScheduler.from_pretrained(args.diffusion_path, subfolder="scheduler")
|
151 |
+
|
152 |
+
vae = AutoencoderKL.from_pretrained(args.diffusion_path, subfolder="vae").to(self.vit_sd_device, dtype=self.dtype)
|
153 |
+
|
154 |
+
unet = UNet2DConditionModel.from_pretrained(args.diffusion_path, subfolder="unet").to(dtype=self.dtype)
|
155 |
+
|
156 |
+
sd_adapter_cfg = OmegaConf.load(args.sd_adapter)
|
157 |
+
|
158 |
+
self.sd_adapter = hydra.utils.instantiate(sd_adapter_cfg, unet=unet).eval().to(dtype=self.dtype)
|
159 |
+
|
160 |
+
self.sd_adapter.init_pipe(vae=vae,
|
161 |
+
scheduler=noise_scheduler,
|
162 |
+
visual_encoder=self.visual_encoder.to("cpu"),
|
163 |
+
image_transform=self.image_transform,
|
164 |
+
discrete_model=None,
|
165 |
+
dtype=self.dtype,
|
166 |
+
device="cpu")
|
167 |
+
|
168 |
+
print('Init sd adapter pipe done.')
|
169 |
+
|
170 |
+
self.visual_encoder.to(self.vit_sd_device, dtype=self.dtype)
|
171 |
+
|
172 |
+
self.boi_token_id = self.tokenizer.encode(BOI_TOKEN, add_special_tokens=False)[0]
|
173 |
+
self.eoi_token_id = self.tokenizer.encode(EOI_TOKEN, add_special_tokens=False)[0]
|
174 |
+
|
175 |
+
self.bop_token_id = self.tokenizer.encode(BOP_TOKEN, add_special_tokens=False)[0]
|
176 |
+
self.eop_token_id = self.tokenizer.encode(EOP_TOKEN, add_special_tokens=False)[0]
|
177 |
+
|
178 |
+
self.multi_resolution = args.multi_resolution
|
179 |
+
if self.multi_resolution:
|
180 |
+
self.base_resolution = base_resolution
|
181 |
+
grid_pinpoints = []
|
182 |
+
for scale in resolution_grids:
|
183 |
+
s1, s2 = scale.split('x')
|
184 |
+
grid_pinpoints.append([int(s1)*base_resolution, int(s2)*base_resolution])
|
185 |
+
self.grid_pinpoints = grid_pinpoints
|
186 |
+
|
187 |
+
|
188 |
+
service = LLMService(args)
|
189 |
+
|
190 |
+
|
191 |
+
@app.route('/generate', methods=['GET', 'POST'])
|
192 |
+
def generate():
|
193 |
+
with torch.no_grad():
|
194 |
+
request_info = request.get_json()
|
195 |
+
|
196 |
+
text_list = request_info['text'].split(IMG_FLAG)
|
197 |
+
image_list = request_info['images']
|
198 |
+
max_new_tokens = request_info.get('max_new_tokens', 256)
|
199 |
+
top_p = 0.5
|
200 |
+
force_boi = request_info.get('force_boi', False)
|
201 |
+
force_bbox = request_info.get('force_bbox', False)
|
202 |
+
|
203 |
+
assert len(text_list) == len(image_list) + 1
|
204 |
+
|
205 |
+
image_tokens = BOI_TOKEN + ''.join([IMG_TOKEN.format(int(item)) for item in range(num_img_in_tokens)]) + EOI_TOKEN
|
206 |
+
|
207 |
+
input_images = []
|
208 |
+
if len(image_list) > 0:
|
209 |
+
image_tensor_list = []
|
210 |
+
embeds_cmp_mask = []
|
211 |
+
embeds_gen_mask = []
|
212 |
+
|
213 |
+
if service.multi_resolution:
|
214 |
+
patch_pos = []
|
215 |
+
image_patch_length = []
|
216 |
+
image_size_list = []
|
217 |
+
|
218 |
+
for idx, image_item in enumerate(image_list):
|
219 |
+
if isinstance(image_item, str):
|
220 |
+
image = decode_image(image_item)
|
221 |
+
print('after decode image size:', image.size)
|
222 |
+
input_images.append(image)
|
223 |
+
|
224 |
+
if service.multi_resolution:
|
225 |
+
image_size_list.append(image.size)
|
226 |
+
print('image size:', image.size)
|
227 |
+
image_tensor, patch_pos_tensor = process_anyres_image(image, service.image_transform, service.grid_pinpoints, service.base_resolution)
|
228 |
+
image_tensor_list.append(image_tensor)
|
229 |
+
patch_pos.append(patch_pos_tensor)
|
230 |
+
image_patch_length.append(image_tensor.shape[0])
|
231 |
+
print('image_patch_length', image_patch_length)
|
232 |
+
embeds_cmp_mask.extend([True]*image_tensor.shape[0])
|
233 |
+
embeds_gen_mask.extend([False]*image_tensor.shape[0])
|
234 |
+
|
235 |
+
else:
|
236 |
+
image_tensor = service.image_transform(image)
|
237 |
+
|
238 |
+
image_tensor_list.append(image_tensor)
|
239 |
+
embeds_cmp_mask.append(True)
|
240 |
+
embeds_gen_mask.append(False)
|
241 |
+
else:
|
242 |
+
raise ValueError
|
243 |
+
|
244 |
+
if service.multi_resolution:
|
245 |
+
pixel_values = torch.cat(image_tensor_list).to(service.vit_sd_device, dtype=service.dtype)
|
246 |
+
patch_position = torch.cat(patch_pos, dim=0)
|
247 |
+
|
248 |
+
image_tokens_list = []
|
249 |
+
for patch_length in image_patch_length:
|
250 |
+
image_tokens = ''
|
251 |
+
for _ in range(patch_length-1):
|
252 |
+
image_tokens += BOP_TOKEN + ''.join(IMG_TOKEN.format(int(item)) for item in range(num_img_in_tokens)) + EOP_TOKEN
|
253 |
+
image_tokens += BOI_TOKEN + ''.join(IMG_TOKEN.format(int(item)) for item in range(num_img_in_tokens)) + EOI_TOKEN
|
254 |
+
image_tokens_list.append(image_tokens)
|
255 |
+
else:
|
256 |
+
pixel_values = torch.stack(image_tensor_list).to(service.vit_sd_device, dtype=service.dtype)
|
257 |
+
|
258 |
+
image_embeds = service.visual_encoder(pixel_values)
|
259 |
+
image_embeds = image_embeds.to(service.llm_device)
|
260 |
+
|
261 |
+
embeds_cmp_mask = torch.tensor(embeds_cmp_mask, dtype=torch.bool).to(service.llm_device)
|
262 |
+
embeds_gen_mask = torch.tensor(embeds_gen_mask, dtype=torch.bool).to(service.llm_device)
|
263 |
+
|
264 |
+
else:
|
265 |
+
image_embeds = None
|
266 |
+
patch_position = 0
|
267 |
+
embeds_cmp_mask = None
|
268 |
+
embeds_gen_mask = None
|
269 |
+
|
270 |
+
if service.multi_resolution:
|
271 |
+
input_text = ''
|
272 |
+
for i, c in enumerate(text_list[:-1]):
|
273 |
+
input_text += c + image_tokens_list[i]
|
274 |
+
input_text += text_list[-1]
|
275 |
+
|
276 |
+
else:
|
277 |
+
input_text = image_tokens.join(text_list)
|
278 |
+
|
279 |
+
if force_boi:
|
280 |
+
input_text = input_text + BOI_TOKEN
|
281 |
+
|
282 |
+
if force_bbox:
|
283 |
+
input_text = input_text + '[[ <box_start>'
|
284 |
+
print('input_text:', input_text)
|
285 |
+
input_ids = service.tokenizer.encode(input_text, add_special_tokens=False)
|
286 |
+
input_ids = [service.tokenizer.bos_token_id] + input_ids
|
287 |
+
|
288 |
+
input_ids = torch.tensor(input_ids).to(service.llm_device, dtype=torch.long)
|
289 |
+
ids_cmp_mask = torch.zeros_like(input_ids, dtype=torch.bool).to(service.llm_device)
|
290 |
+
ids_gen_mask = torch.zeros_like(input_ids, dtype=torch.bool).to(service.llm_device)
|
291 |
+
|
292 |
+
if service.multi_resolution:
|
293 |
+
boi_indices = torch.where(torch.logical_or(input_ids == service.boi_token_id, input_ids == service.bop_token_id))[0].tolist()
|
294 |
+
eoi_indices = torch.where(torch.logical_or(input_ids == service.eoi_token_id, input_ids == service.eop_token_id))[0].tolist()
|
295 |
+
|
296 |
+
else:
|
297 |
+
|
298 |
+
boi_indices = torch.where(input_ids == service.boi_token_id)[0].tolist()
|
299 |
+
eoi_indices = torch.where(input_ids == service.eoi_token_id)[0].tolist()
|
300 |
+
|
301 |
+
for boi_idx, eoi_idx in zip(boi_indices, eoi_indices):
|
302 |
+
ids_cmp_mask[boi_idx + 1:eoi_idx] = True
|
303 |
+
|
304 |
+
input_ids = input_ids.unsqueeze(0)
|
305 |
+
ids_cmp_mask = ids_cmp_mask.unsqueeze(0)
|
306 |
+
ids_gen_mask = ids_gen_mask.unsqueeze(0)
|
307 |
+
|
308 |
+
error_msg = []
|
309 |
+
|
310 |
+
if service.multi_resolution:
|
311 |
+
output = service.agent.generate(
|
312 |
+
tokenizer=service.tokenizer,
|
313 |
+
input_ids=input_ids,
|
314 |
+
image_embeds=image_embeds,
|
315 |
+
patch_positions=patch_position,
|
316 |
+
embeds_cmp_mask=embeds_cmp_mask,
|
317 |
+
ids_cmp_mask=ids_cmp_mask,
|
318 |
+
num_img_gen_tokens=num_img_out_tokens,
|
319 |
+
max_new_tokens=max_new_tokens,
|
320 |
+
dtype=service.dtype,
|
321 |
+
device=service.llm_device,
|
322 |
+
top_p=top_p,
|
323 |
+
)
|
324 |
+
else:
|
325 |
+
output = service.agent.generate(
|
326 |
+
tokenizer=service.tokenizer,
|
327 |
+
input_ids=input_ids,
|
328 |
+
image_embeds=image_embeds,
|
329 |
+
embeds_cmp_mask=embeds_cmp_mask,
|
330 |
+
ids_cmp_mask=ids_cmp_mask,
|
331 |
+
num_img_gen_tokens=num_img_out_tokens,
|
332 |
+
max_new_tokens=max_new_tokens,
|
333 |
+
dtype=service.dtype,
|
334 |
+
device=service.llm_device,
|
335 |
+
top_p=top_p,
|
336 |
+
)
|
337 |
+
|
338 |
+
gen_imgs_base64_list = []
|
339 |
+
generated_text = output['text']
|
340 |
+
generated_text = generated_text.replace(EOI_TOKEN, IMG_FLAG).replace(service.tokenizer.eos_token, '')
|
341 |
+
|
342 |
+
if output['has_img_output']:
|
343 |
+
print('Offloading the LLM to CPU and moving the SD adapter to GPU')
|
344 |
+
a = time.time()
|
345 |
+
service.agent = service.agent.to("cpu")
|
346 |
+
service.sd_adapter = service.sd_adapter.to(service.vit_sd_device, dtype=service.dtype)
|
347 |
+
print("Loading finished: ", time.time() - a)
|
348 |
+
|
349 |
+
img_gen_feat = output['img_gen_feat'].to(service.vit_sd_device, dtype=service.dtype)
|
350 |
+
|
351 |
+
for img_idx in range(output['num_gen_imgs']):
|
352 |
+
img_feat = img_gen_feat[img_idx:img_idx + 1]
|
353 |
+
generated_image = service.sd_adapter.generate(image_embeds=img_feat, num_inference_steps=50)[0]
|
354 |
+
image_base64 = encode_image(generated_image)
|
355 |
+
gen_imgs_base64_list.append(image_base64)
|
356 |
+
|
357 |
+
print('Moving the SD adapter to CPU and the visual encoder and LLM back to GPU')
|
358 |
+
a = time.time()
|
359 |
+
service.sd_adapter = service.sd_adapter.to("cpu")
|
360 |
+
service.visual_encoder = service.visual_encoder.to(service.vit_sd_device, dtype=service.dtype)
|
361 |
+
service.agent = service.agent.to(service.vit_sd_device, dtype=service.dtype)
|
362 |
+
print("Loading finished: ", time.time() - a)
|
363 |
+
|
364 |
+
if args.has_bbox:
|
365 |
+
bboxes = extract_box(generated_text)
|
366 |
+
|
367 |
+
if bboxes is not None and len(input_images) > 0:
|
368 |
+
image_viz = visualize_bbox(input_images[0], bboxes)
|
369 |
+
image_base64 = encode_image(image_viz)
|
370 |
+
gen_imgs_base64_list.append(image_base64)
|
371 |
+
generated_text = re.sub(r'\[\[ <box_start>.*?<box_end>.*?\]\]', 'the green bounding box', generated_text)
|
372 |
+
generated_text += IMG_FLAG
|
373 |
+
print(input_text + generated_text)
|
374 |
+
|
375 |
+
return {'text': generated_text, 'images': gen_imgs_base64_list, 'error_msg': error_msg}
|
376 |
+
|
377 |
+
|
378 |
+
if __name__ == '__main__':
|
379 |
+
app.run(host='0.0.0.0', port=args.port)
|
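A quick note on the service above: a minimal client for the /generate endpoint could look like the following sketch. It mirrors the payload that the Gradio frontend (seed_llama_gradio.py, added below) sends: the prompt text contains one <image> placeholder per attached image, images travel as base64-encoded PNGs, and the response carries the generated text plus any generated images as base64 strings. The server URL and image path here are placeholder assumptions; also note that the Flask service's port defaults to 80 while the Gradio frontend's request_address defaults to port 7890, so the two must be aligned at launch.

# Minimal client sketch for the /generate endpoint (URL and image path are placeholders).
import base64
import io

import requests
from PIL import Image


def encode_image(image: Image.Image) -> str:
    with io.BytesIO() as buffer:
        image.save(buffer, format='PNG')
        return base64.b64encode(buffer.getvalue()).decode('utf-8')


image = Image.open('seed_x/bank.png').convert('RGB')
payload = {
    'text': '<image>Can I connect with an advisor on Sunday?',  # one <image> flag per attached image
    'images': [encode_image(image)],
    'max_new_tokens': 512,
    'force_boi': False,   # set True to force image generation
    'force_bbox': False,  # set True to force a bounding-box response
}

response = requests.post('http://127.0.0.1:7890/generate', json=payload).json()
print(response['text'])
for encoded in response['images']:
    Image.open(io.BytesIO(base64.b64decode(encoded))).show()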
src/demo/seed_llama_gradio.py
ADDED
@@ -0,0 +1,465 @@
1 |
+
import os
|
2 |
+
import numpy as np
|
3 |
+
import datetime
|
4 |
+
import json
|
5 |
+
from typing import Optional
|
6 |
+
import transformers
|
7 |
+
from dataclasses import dataclass, field
|
8 |
+
import io
|
9 |
+
import base64
|
10 |
+
from PIL import Image
|
11 |
+
import gradio as gr
|
12 |
+
import time
|
13 |
+
import hashlib
|
14 |
+
import requests
|
15 |
+
|
16 |
+
from utils import build_logger
|
17 |
+
from conversation import conv_seed_llama2
|
18 |
+
|
19 |
+
IMG_FLAG = '<image>'
|
20 |
+
LOGDIR = 'log'
|
21 |
+
|
22 |
+
logger = build_logger("gradio_seed_x", LOGDIR)
|
23 |
+
headers = {"User-Agent": "SEED-X Client"}
|
24 |
+
|
25 |
+
no_change_btn = gr.Button.update()
|
26 |
+
enable_btn = gr.Button.update(interactive=True)
|
27 |
+
disable_btn = gr.Button.update(interactive=False)
|
28 |
+
|
29 |
+
|
30 |
+
@dataclass
|
31 |
+
class Arguments:
|
32 |
+
server_port: Optional[int] = field(default=7860, metadata={"help": "network port"})
|
33 |
+
server_name: Optional[str] = field(default='0.0.0.0', metadata={"help": "network address"})
|
34 |
+
request_address: Optional[str] = field(default='http://127.0.0.1:7890/generate',
|
35 |
+
metadata={"help": "request address"})
|
36 |
+
|
37 |
+
|
38 |
+
parser = transformers.HfArgumentParser(Arguments)
|
39 |
+
args, = parser.parse_args_into_dataclasses()
|
40 |
+
conv_seed_llama = conv_seed_llama2
|
41 |
+
|
42 |
+
|
43 |
+
def decode_image(encoded_image: str) -> Image:
|
44 |
+
decoded_bytes = base64.b64decode(encoded_image.encode('utf-8'))
|
45 |
+
buffer = io.BytesIO(decoded_bytes)
|
46 |
+
image = Image.open(buffer)
|
47 |
+
return image
|
48 |
+
|
49 |
+
|
50 |
+
def encode_image(image: Image.Image, format: str = 'PNG') -> str:
|
51 |
+
with io.BytesIO() as buffer:
|
52 |
+
image.save(buffer, format=format)
|
53 |
+
encoded_image = base64.b64encode(buffer.getvalue()).decode('utf-8')
|
54 |
+
return encoded_image
|
55 |
+
|
56 |
+
|
57 |
+
def get_conv_log_filename():
|
58 |
+
t = datetime.datetime.now()
|
59 |
+
name = os.path.join(LOGDIR, f"{t.year}-{t.month:02d}-{t.day:02d}-conv.json")
|
60 |
+
return name
|
61 |
+
|
62 |
+
|
63 |
+
def get_conv_image_dir():
|
64 |
+
name = os.path.join(LOGDIR, 'images')
|
65 |
+
os.makedirs(name, exist_ok=True)
|
66 |
+
return name
|
67 |
+
|
68 |
+
|
69 |
+
def get_image_name(image, image_dir=None):
|
70 |
+
buffer = io.BytesIO()
|
71 |
+
image.save(buffer, format='PNG')
|
72 |
+
image_bytes = buffer.getvalue()
|
73 |
+
md5 = hashlib.md5(image_bytes).hexdigest()
|
74 |
+
|
75 |
+
if image_dir is not None:
|
76 |
+
image_name = os.path.join(image_dir, md5 + '.png')
|
77 |
+
else:
|
78 |
+
image_name = md5 + '.png'
|
79 |
+
|
80 |
+
return image_name
|
81 |
+
|
82 |
+
|
83 |
+
def resize_image_square(image, target_size=448):
|
84 |
+
resized_image = image.resize((target_size, target_size))
|
85 |
+
return resized_image
|
86 |
+
|
87 |
+
|
88 |
+
def resize_image(image, max_size=512):
|
89 |
+
width, height = image.size
|
90 |
+
aspect_ratio = float(width) / float(height)
|
91 |
+
|
92 |
+
if width > height:
|
93 |
+
new_width = max_size
|
94 |
+
new_height = int(new_width / aspect_ratio)
|
95 |
+
else:
|
96 |
+
new_height = max_size
|
97 |
+
new_width = int(new_height * aspect_ratio)
|
98 |
+
|
99 |
+
resized_image = image.resize((new_width, new_height))
|
100 |
+
return resized_image
|
101 |
+
|
102 |
+
|
103 |
+
def center_crop_image(image, max_aspect_ratio=1.5):
|
104 |
+
width, height = image.size
|
105 |
+
aspect_ratio = max(width, height) / min(width, height)
|
106 |
+
|
107 |
+
if aspect_ratio >= max_aspect_ratio:
|
108 |
+
if width > height:
|
109 |
+
new_width = int(height * max_aspect_ratio)
|
110 |
+
left = (width - new_width) // 2
|
111 |
+
right = (width + new_width) // 2
|
112 |
+
top = 0
|
113 |
+
bottom = height
|
114 |
+
else:
|
115 |
+
new_height = int(width * max_aspect_ratio)
|
116 |
+
left = 0
|
117 |
+
right = width
|
118 |
+
top = (height - new_height) // 2
|
119 |
+
bottom = (height + new_height) // 2
|
120 |
+
|
121 |
+
cropped_image = image.crop((left, top, right, bottom))
|
122 |
+
return cropped_image
|
123 |
+
else:
|
124 |
+
return image
|
125 |
+
|
126 |
+
|
127 |
+
def vote_last_response(state, vote_type, request: gr.Request):
|
128 |
+
with open(get_conv_log_filename(), "a") as fout:
|
129 |
+
data = {
|
130 |
+
"tstamp": round(time.time(), 4),
|
131 |
+
"type": vote_type,
|
132 |
+
"state": state.dict(),
|
133 |
+
"ip": request.client.host,
|
134 |
+
}
|
135 |
+
fout.write(json.dumps(data) + "\n")
|
136 |
+
|
137 |
+
|
138 |
+
def upvote_last_response(state, request: gr.Request):
|
139 |
+
logger.info(f"upvote. ip: {request.client.host}")
|
140 |
+
vote_last_response(state, "upvote", request)
|
141 |
+
return (disable_btn,) * 2
|
142 |
+
|
143 |
+
|
144 |
+
def downvote_last_response(state, request: gr.Request):
|
145 |
+
logger.info(f"downvote. ip: {request.client.host}")
|
146 |
+
vote_last_response(state, "downvote", request)
|
147 |
+
return (disable_btn,) * 2
|
148 |
+
|
149 |
+
|
150 |
+
def regenerate(dialog_state, request: gr.Request):
|
151 |
+
logger.info(f"regenerate. ip: {request.client.host}")
|
152 |
+
if dialog_state.messages[-1]['role'] == dialog_state.roles[1]:
|
153 |
+
dialog_state.messages.pop()
|
154 |
+
return (
|
155 |
+
dialog_state,
|
156 |
+
dialog_state.to_gradio_chatbot(),
|
157 |
+
) + (disable_btn,) * 4
|
158 |
+
|
159 |
+
|
160 |
+
def clear_history(request: gr.Request):
|
161 |
+
logger.info(f"clear_history. ip: {request.client.host}")
|
162 |
+
dialog_state = conv_seed_llama.copy()
|
163 |
+
input_state = init_input_state()
|
164 |
+
return (dialog_state, input_state, dialog_state.to_gradio_chatbot()) + (disable_btn,) * 4
|
165 |
+
|
166 |
+
|
167 |
+
def init_input_state():
|
168 |
+
return {'images': [], 'text': ''}
|
169 |
+
|
170 |
+
|
171 |
+
def add_text(dialog_state, input_state, text, request: gr.Request):
|
172 |
+
logger.info(f"add_text. ip: {request.client.host}.")
|
173 |
+
# if len(input_state['text']) == 0:
|
174 |
+
if text is None or len(text) == 0:
|
175 |
+
# dialog_state.skip_next = True
|
176 |
+
return (dialog_state, input_state, "", dialog_state.to_gradio_chatbot()) + (no_change_btn,) * 4
|
177 |
+
input_state['text'] += text
|
178 |
+
|
179 |
+
|
180 |
+
if len(dialog_state.messages) > 0 and dialog_state.messages[-1]['role'] == dialog_state.roles[0]:
|
181 |
+
dialog_state.messages[-1]['message'] = input_state
|
182 |
+
else:
|
183 |
+
dialog_state.messages.append({'role': dialog_state.roles[0], 'message': input_state})
|
184 |
+
print('add_text: ', dialog_state.to_gradio_chatbot())
|
185 |
+
|
186 |
+
return (dialog_state, input_state, "", dialog_state.to_gradio_chatbot()) + (disable_btn,) * 4
|
187 |
+
|
188 |
+
|
189 |
+
def is_blank(image):
|
190 |
+
image_array = np.array(image)
|
191 |
+
unique_colors = np.unique(image_array)
|
192 |
+
print('unique_colors', len(unique_colors))
|
193 |
+
return len(unique_colors) == 1
|
194 |
+
|
195 |
+
|
196 |
+
def add_image(dialog_state, input_state, image, request: gr.Request):
|
197 |
+
logger.info(f"add_image. ip: {request.client.host}.")
|
198 |
+
if image is None:
|
199 |
+
return (dialog_state, input_state, None, dialog_state.to_gradio_chatbot()) + (no_change_btn,) * 4
|
200 |
+
|
201 |
+
image = image.convert('RGB')
|
202 |
+
|
203 |
+
print('image size:', image.size)
|
204 |
+
|
205 |
+
image = center_crop_image(image, max_aspect_ratio=10)
|
206 |
+
|
207 |
+
image_dir = get_conv_image_dir()
|
208 |
+
image_path = get_image_name(image=image, image_dir=image_dir)
|
209 |
+
if not os.path.exists(image_path):
|
210 |
+
image.save(image_path)
|
211 |
+
input_state['images'].append(image_path)
|
212 |
+
input_state['text'] += IMG_FLAG
|
213 |
+
|
214 |
+
if len(dialog_state.messages) > 0 and dialog_state.messages[-1]['role'] == dialog_state.roles[0]:
|
215 |
+
dialog_state.messages[-1]['message'] = input_state
|
216 |
+
else:
|
217 |
+
dialog_state.messages.append({'role': dialog_state.roles[0], 'message': input_state})
|
218 |
+
|
219 |
+
print('add_image:', dialog_state)
|
220 |
+
|
221 |
+
return (dialog_state, input_state, None, dialog_state.to_gradio_chatbot()) + (disable_btn,) * 4
|
222 |
+
|
223 |
+
|
224 |
+
def http_bot(dialog_state, input_state, max_new_tokens, max_turns, force_image_gen, force_bbox,
|
225 |
+
request: gr.Request):
|
226 |
+
logger.info(f"http_bot. ip: {request.client.host}")
|
227 |
+
print('input_state:', input_state)
|
228 |
+
|
229 |
+
if len(dialog_state.messages) == 0 or dialog_state.messages[-1]['role'] != dialog_state.roles[0] or len(
|
230 |
+
dialog_state.messages[-1]['message']['text'].strip(' ?.;!/')) == 0:
|
231 |
+
return (dialog_state, input_state, dialog_state.to_gradio_chatbot()) + (no_change_btn,) * 4
|
232 |
+
|
233 |
+
if len(dialog_state.messages) > max_turns * 2:
|
234 |
+
output_state = init_input_state()
|
235 |
+
output_state['text'] = 'Error: History exceeds maximum rounds, please clear history and restart.'
|
236 |
+
dialog_state.messages.append({'role': dialog_state.roles[1], 'message': output_state})
|
237 |
+
input_state = init_input_state()
|
238 |
+
return (dialog_state, input_state, dialog_state.to_gradio_chatbot()) + (disable_btn,) * 3 + (enable_btn,)
|
239 |
+
|
240 |
+
prompt = dialog_state.get_prompt()
|
241 |
+
payload = {
|
242 |
+
'text': prompt['text'],
|
243 |
+
'max_new_tokens': int(max_new_tokens),
|
244 |
+
'images': prompt['images'],
|
245 |
+
'force_boi': force_image_gen,
|
246 |
+
'force_bbox': force_bbox,
|
247 |
+
}
|
248 |
+
|
249 |
+
print(
|
250 |
+
'request: ', {
|
251 |
+
'text': prompt['text'],
|
252 |
+
'max_new_tokens': int(max_new_tokens),
|
253 |
+
})
|
254 |
+
print('request_address', args.request_address)
|
255 |
+
response = requests.request(method="POST", url=args.request_address, headers=headers, json=payload)
|
256 |
+
results = response.json()
|
257 |
+
print('response: ', {'text': results['text'], 'error_msg': results['error_msg']})
|
258 |
+
|
259 |
+
output_state = init_input_state()
|
260 |
+
image_dir = get_conv_image_dir()
|
261 |
+
output_state['text'] = results['text']
|
262 |
+
|
263 |
+
for image_base64 in results['images']:
|
264 |
+
if image_base64 == '':
|
265 |
+
image_path = ''
|
266 |
+
else:
|
267 |
+
image = decode_image(image_base64)
|
268 |
+
image = image.convert('RGB')
|
269 |
+
image_path = get_image_name(image=image, image_dir=image_dir)
|
270 |
+
if not os.path.exists(image_path):
|
271 |
+
image.save(image_path)
|
272 |
+
output_state['images'].append(image_path)
|
273 |
+
|
274 |
+
dialog_state.messages.append({'role': dialog_state.roles[1], 'message': output_state})
|
275 |
+
|
276 |
+
vote_last_response(dialog_state, 'common', request)
|
277 |
+
input_state = init_input_state()
|
278 |
+
chatbot = update_error_msg(dialog_state.to_gradio_chatbot(), results['error_msg'])
|
279 |
+
return (dialog_state, input_state, chatbot) + (enable_btn,) * 4
|
280 |
+
|
281 |
+
|
282 |
+
def update_error_msg(chatbot, error_msg):
|
283 |
+
if len(error_msg) > 0:
|
284 |
+
info = '\n-------------\nSome errors occurred during response, please clear history and restart.\n' + '\n'.join(
|
285 |
+
error_msg)
|
286 |
+
chatbot[-1][-1] = chatbot[-1][-1] + info
|
287 |
+
|
288 |
+
return chatbot
|
289 |
+
|
290 |
+
|
291 |
+
def load_demo(request: gr.Request):
|
292 |
+
logger.info(f"load_demo. ip: {request.client.host}")
|
293 |
+
dialog_state = conv_seed_llama.copy()
|
294 |
+
input_state = init_input_state()
|
295 |
+
return dialog_state, input_state
|
296 |
+
|
297 |
+
|
298 |
+
title = ("""
|
299 |
+
# SEED-X-I
|
300 |
+
[[Paper]](https://arxiv.org/abs/2404.14396) [[Code]](https://github.com/AILab-CVC/SEED-X)
|
301 |
+
|
302 |
+
Demo of a general instruction-tuned model SEED-X-I (17B) from the foundation model SEED-X.
|
303 |
+
|
304 |
+
SEED-X-I can follow multimodal instructions (including images with **dynamic resolutions**) and respond with **images, texts and bounding boxes** in multi-turn conversations.
|
305 |
+
|
306 |
+
SEED-X-I **does not support image manipulation**. If you want to experience **SEED-X-Edit** for high-precision image editing, please refer to [[Inference Code]](https://github.com/AILab-CVC/SEED-X).
|
307 |
+
|
308 |
+
Due to insufficient GPU memory, when generating images we need to offload the LLM to the CPU and move the de-tokenizer to the GPU, which will **result in a long processing time**. If you want to experience the normal model inference speed, you can run the [[Inference Code]](https://github.com/AILab-CVC/SEED-X) locally.
|
309 |
+
|
310 |
+
|
311 |
+
## Tips:
|
312 |
+
* Check out the conversation examples (at the bottom) for inspiration.
|
313 |
+
|
314 |
+
* You can adjust "Max History Rounds" to try a conversation with up to five rounds. For more turns, you can download our checkpoints from GitHub and deploy them locally for inference.
|
315 |
+
|
316 |
+
* Our demo supports a mix of images and texts as input. You can freely upload an image or enter text, then click "Add Image/Text". You can repeat this step multiple times, and finally click "Submit" to run model inference.
|
317 |
+
|
318 |
+
* You can click "Force Image Generation" to compel the model to produce images when necessary. For example, our model might struggle to generate images when there is an excessive amount of text-only context.
|
319 |
+
|
320 |
+
* You can click "Force Bounding Box" to compel the model to produce bounding box for object detection.
|
321 |
+
|
322 |
+
* SEED-X was trained with English-only data. It may work with other languages due to the inherent capabilities of LLaMA, but might not be stable.
|
323 |
+
|
324 |
+
""")
|
325 |
+
|
326 |
+
css = """
|
327 |
+
img {
|
328 |
+
font-family: 'Helvetica';
|
329 |
+
font-weight: 300;
|
330 |
+
line-height: 2;
|
331 |
+
text-align: center;
|
332 |
+
|
333 |
+
width: auto;
|
334 |
+
height: auto;
|
335 |
+
display: block;
|
336 |
+
position: relative;
|
337 |
+
}
|
338 |
+
|
339 |
+
img:before {
|
340 |
+
content: " ";
|
341 |
+
display: block;
|
342 |
+
|
343 |
+
position: absolute;
|
344 |
+
top: -10px;
|
345 |
+
left: 0;
|
346 |
+
height: calc(100% + 10px);
|
347 |
+
width: 100%;
|
348 |
+
background-color: rgb(230, 230, 230);
|
349 |
+
border: 2px dotted rgb(200, 200, 200);
|
350 |
+
border-radius: 5px;
|
351 |
+
}
|
352 |
+
|
353 |
+
img:after {
|
354 |
+
content: " ";
|
355 |
+
display: block;
|
356 |
+
font-size: 16px;
|
357 |
+
font-style: normal;
|
358 |
+
font-family: FontAwesome;
|
359 |
+
color: rgb(100, 100, 100);
|
360 |
+
|
361 |
+
position: absolute;
|
362 |
+
top: 5px;
|
363 |
+
left: 0;
|
364 |
+
width: 100%;
|
365 |
+
text-align: center;
|
366 |
+
}
|
367 |
+
|
368 |
+
"""
|
369 |
+
|
370 |
+
if __name__ == '__main__':
|
371 |
+
|
372 |
+
examples_mix = [
|
373 |
+
['seed_x/bank.png', 'Can I connect with an advisor on Sunday?'],
|
374 |
+
['seed_x/ground.png',
|
375 |
+
'Is there anything in the image that can protect me from catching the flu virus when I go out? Show me the location.'],
|
376 |
+
['seed_x/arrow.jpg', 'What is the object pointed by the red arrow?'],
|
377 |
+
['seed_x/shanghai.png', 'Where was this image taken? Explain your answer.'],
|
378 |
+
['seed_x/GPT4.png', 'How long does it take to make GPT-4 safer?'],
|
379 |
+
['seed_x/twitter.png',
|
380 |
+
'Please provide a comprehensive description of this image.'],
|
381 |
+
]
|
382 |
+
|
383 |
+
examples_text = [
|
384 |
+
['I want to build a two story cabin in the woods, with many commanding windows. Can you show me a picture?'],
|
385 |
+
['Use your imagination to design a concept image for Artificial General Intelligence (AGI). Show me an image.'],
|
386 |
+
[
|
387 |
+
'Can you design an illustration for “The Three-Body Problem” to depict a scene from the novel? Show me a picture.'],
|
388 |
+
[
|
389 |
+
'My four year old son loves toy trains. Can you design a fancy birthday cake for him? Please generate a picture.'],
|
390 |
+
[
|
391 |
+
'Generate an image of a portrait of a young Nordic girl, age 25, freckled skin, neck tattoo, blue eyes, 35mm lens, photography, ultra details.'],
|
392 |
+
['Generate an impressionist painting of an astronaut in a jungle.']
|
393 |
+
]
|
394 |
+
with gr.Blocks(css=css) as demo:
|
395 |
+
gr.Markdown(title)
|
396 |
+
dialog_state = gr.State()
|
397 |
+
input_state = gr.State()
|
398 |
+
with gr.Row():
|
399 |
+
with gr.Column(scale=3):
|
400 |
+
with gr.Row():
|
401 |
+
image = gr.Image(type='pil', label='input_image')
|
402 |
+
with gr.Row():
|
403 |
+
text = gr.Textbox(lines=5,
|
404 |
+
show_label=False,
|
405 |
+
label='input_text',
|
406 |
+
elem_id='textbox',
|
407 |
+
placeholder="Enter text or add image, and press submit,").style(container=False)
|
408 |
+
with gr.Row():
|
409 |
+
add_image_btn = gr.Button("Add Image")
|
410 |
+
add_text_btn = gr.Button("Add Text")
|
411 |
+
|
412 |
+
submit_btn = gr.Button("Submit")
|
413 |
+
|
414 |
+
with gr.Row():
|
415 |
+
max_new_tokens = gr.Slider(minimum=64,
|
416 |
+
maximum=1024,
|
417 |
+
value=768,
|
418 |
+
step=64,
|
419 |
+
interactive=True,
|
420 |
+
label="Max Output Tokens")
|
421 |
+
max_turns = gr.Slider(minimum=1, maximum=9, value=3, step=1, interactive=True,
|
422 |
+
label="Max History Rounds")
|
423 |
+
force_img_gen = gr.Radio(choices=[True, False], value=False, label='Force Image Generation')
|
424 |
+
force_bbox = gr.Radio(choices=[True, False], value=False, label='Force Bounding Box')
|
425 |
+
|
426 |
+
with gr.Column(scale=7):
|
427 |
+
chatbot = gr.Chatbot(elem_id='chatbot', label="SEED-X-I").style(height=700)
|
428 |
+
with gr.Row():
|
429 |
+
upvote_btn = gr.Button(value="👍 Upvote", interactive=False)
|
430 |
+
downvote_btn = gr.Button(value="👎 Downvote", interactive=False)
|
431 |
+
regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False)
|
432 |
+
clear_btn = gr.Button(value="🗑️ Clear history", interactive=False)
|
433 |
+
|
434 |
+
with gr.Row():
|
435 |
+
with gr.Column(scale=0.7):
|
436 |
+
gr.Examples(examples=examples_mix, label='Input examples', inputs=[image, text])
|
437 |
+
with gr.Column(scale=0.3):
|
438 |
+
gr.Examples(examples=examples_text, label='Input examples', inputs=[text])
|
439 |
+
|
440 |
+
# Register listeners
|
441 |
+
btn_list = [upvote_btn, downvote_btn, regenerate_btn, clear_btn]
|
442 |
+
upvote_btn.click(upvote_last_response, [dialog_state], [upvote_btn, downvote_btn])
|
443 |
+
downvote_btn.click(downvote_last_response, [dialog_state], [upvote_btn, downvote_btn])
|
444 |
+
|
445 |
+
regenerate_btn.click(regenerate, [dialog_state], [dialog_state, chatbot] + btn_list).then(
|
446 |
+
http_bot, [dialog_state, input_state, max_new_tokens, max_turns, force_img_gen, force_bbox],
|
447 |
+
[dialog_state, input_state, chatbot] + btn_list)
|
448 |
+
add_image_btn.click(add_image, [dialog_state, input_state, image],
|
449 |
+
[dialog_state, input_state, image, chatbot] + btn_list)
|
450 |
+
|
451 |
+
add_text_btn.click(add_text, [dialog_state, input_state, text],
|
452 |
+
[dialog_state, input_state, text, chatbot] + btn_list)
|
453 |
+
|
454 |
+
submit_btn.click(
|
455 |
+
add_image, [dialog_state, input_state, image], [dialog_state, input_state, image, chatbot] + btn_list).then(
|
456 |
+
add_text, [dialog_state, input_state, text],
|
457 |
+
[dialog_state, input_state, text, chatbot, upvote_btn, downvote_btn, regenerate_btn, clear_btn]).then(
|
458 |
+
http_bot,
|
459 |
+
[dialog_state, input_state, max_new_tokens, max_turns, force_img_gen, force_bbox],
|
460 |
+
[dialog_state, input_state, chatbot] + btn_list)
|
461 |
+
clear_btn.click(clear_history, None, [dialog_state, input_state, chatbot] + btn_list)
|
462 |
+
|
463 |
+
demo.load(load_demo, None, [dialog_state, input_state])
|
464 |
+
|
465 |
+
demo.launch(server_name=args.server_name, server_port=args.server_port, enable_queue=True)
|
src/demo/utils.py
ADDED
@@ -0,0 +1,83 @@
1 |
+
import datetime
|
2 |
+
import logging
|
3 |
+
import logging.handlers
|
4 |
+
import os
|
5 |
+
import sys
|
6 |
+
|
7 |
+
handler = None
|
8 |
+
|
9 |
+
|
10 |
+
def build_logger(logger_name, logger_dir):
|
11 |
+
global handler
|
12 |
+
|
13 |
+
formatter = logging.Formatter(
|
14 |
+
fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
|
15 |
+
datefmt="%Y-%m-%d %H:%M:%S",
|
16 |
+
)
|
17 |
+
|
18 |
+
# Set the format of root handlers
|
19 |
+
if not logging.getLogger().handlers:
|
20 |
+
logging.basicConfig(level=logging.INFO)
|
21 |
+
logging.getLogger().handlers[0].setFormatter(formatter)
|
22 |
+
|
23 |
+
# Redirect stdout and stderr to loggers
|
24 |
+
stdout_logger = logging.getLogger("stdout")
|
25 |
+
stdout_logger.setLevel(logging.INFO)
|
26 |
+
sl = StreamToLogger(stdout_logger, logging.INFO)
|
27 |
+
sys.stdout = sl
|
28 |
+
|
29 |
+
stderr_logger = logging.getLogger("stderr")
|
30 |
+
stderr_logger.setLevel(logging.ERROR)
|
31 |
+
sl = StreamToLogger(stderr_logger, logging.ERROR)
|
32 |
+
sys.stderr = sl
|
33 |
+
|
34 |
+
# Get logger
|
35 |
+
logger = logging.getLogger(logger_name)
|
36 |
+
logger.setLevel(logging.INFO)
|
37 |
+
|
38 |
+
# Add a file handler for all loggers
|
39 |
+
if handler is None:
|
40 |
+
os.makedirs(logger_dir, exist_ok=True)
|
41 |
+
filename = os.path.join(logger_dir, logger_name + '.log')
|
42 |
+
handler = logging.handlers.TimedRotatingFileHandler(filename, when='D', utc=True)
|
43 |
+
handler.setFormatter(formatter)
|
44 |
+
|
45 |
+
for name, item in logging.root.manager.loggerDict.items():
|
46 |
+
if isinstance(item, logging.Logger):
|
47 |
+
item.addHandler(handler)
|
48 |
+
|
49 |
+
return logger
|
50 |
+
|
51 |
+
|
52 |
+
class StreamToLogger(object):
|
53 |
+
"""
|
54 |
+
Fake file-like stream object that redirects writes to a logger instance.
|
55 |
+
"""
|
56 |
+
|
57 |
+
def __init__(self, logger, log_level=logging.INFO):
|
58 |
+
self.terminal = sys.stdout
|
59 |
+
self.logger = logger
|
60 |
+
self.log_level = log_level
|
61 |
+
self.linebuf = ''
|
62 |
+
|
63 |
+
def __getattr__(self, attr):
|
64 |
+
return getattr(self.terminal, attr)
|
65 |
+
|
66 |
+
def write(self, buf):
|
67 |
+
temp_linebuf = self.linebuf + buf
|
68 |
+
self.linebuf = ''
|
69 |
+
for line in temp_linebuf.splitlines(True):
|
70 |
+
# From the io.TextIOWrapper docs:
|
71 |
+
# On output, if newline is None, any '\n' characters written
|
72 |
+
# are translated to the system default line separator.
|
73 |
+
# By default sys.stdout.write() expects '\n' newlines and then
|
74 |
+
# translates them so this is still cross platform.
|
75 |
+
if line[-1] == '\n':
|
76 |
+
self.logger.log(self.log_level, line.rstrip())
|
77 |
+
else:
|
78 |
+
self.linebuf += line
|
79 |
+
|
80 |
+
def flush(self):
|
81 |
+
if self.linebuf != '':
|
82 |
+
self.logger.log(self.log_level, self.linebuf.rstrip())
|
83 |
+
self.linebuf = ''
|
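As seen in seed_llama_gradio.py above, build_logger not only returns a named logger but also redirects stdout and stderr into the logging system, so plain print calls end up in the daily-rotating log file as well. A minimal usage sketch (the logger name and directory below are arbitrary examples):

# Usage sketch for src/demo/utils.py (logger name and directory are arbitrary).
from utils import build_logger

logger = build_logger('example_service', 'log')  # writes to log/example_service.log with daily rotation
logger.info('service started')
print('this also lands in the log file, because sys.stdout is redirected to a logger')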
src/inference/.DS_Store
ADDED
Binary file (6.15 kB). View file
|
|
src/inference/__pycache__/any_res.cpython-311.pyc
ADDED
Binary file (12.2 kB). View file
|
|
src/inference/__pycache__/any_res.cpython-38.pyc
ADDED
Binary file (7.47 kB). View file
|
|
src/inference/any_res.py
ADDED
@@ -0,0 +1,257 @@
1 |
+
import base64
|
2 |
+
import torch
|
3 |
+
import math
|
4 |
+
import ast
|
5 |
+
from PIL import Image
|
6 |
+
from io import BytesIO
|
7 |
+
|
8 |
+
|
9 |
+
def select_best_resolution(original_size, possible_resolutions):
|
10 |
+
"""
|
11 |
+
Selects the best resolution from a list of possible resolutions based on the original size.
|
12 |
+
|
13 |
+
Args:
|
14 |
+
original_size (tuple): The original size of the image in the format (width, height).
|
15 |
+
possible_resolutions (list): A list of possible resolutions in the format [(width1, height1), (width2, height2), ...].
|
16 |
+
|
17 |
+
Returns:
|
18 |
+
tuple: The best fit resolution in the format (width, height).
|
19 |
+
"""
|
20 |
+
original_width, original_height = original_size
|
21 |
+
best_fit = None
|
22 |
+
max_effective_resolution = 0
|
23 |
+
min_wasted_resolution = float('inf')
|
24 |
+
|
25 |
+
for width, height in possible_resolutions:
|
26 |
+
scale = min(width / original_width, height / original_height)
|
27 |
+
downscaled_width, downscaled_height = int(original_width * scale), int(original_height * scale)
|
28 |
+
effective_resolution = min(downscaled_width * downscaled_height, original_width * original_height)
|
29 |
+
wasted_resolution = (width * height) - effective_resolution
|
30 |
+
|
31 |
+
if effective_resolution > max_effective_resolution or (effective_resolution == max_effective_resolution and wasted_resolution < min_wasted_resolution):
|
32 |
+
max_effective_resolution = effective_resolution
|
33 |
+
min_wasted_resolution = wasted_resolution
|
34 |
+
best_fit = (width, height)
|
35 |
+
|
36 |
+
return best_fit
|
37 |
+
|
38 |
+
|
39 |
+
def select_best_resolution_v2(original_size, possible_resolutions):
|
40 |
+
"""
|
41 |
+
Selects the best resolution from a list of possible resolutions based on the original size and aspect ratio.
|
42 |
+
|
43 |
+
Args:
|
44 |
+
original_size (tuple): The original size of the image in the format (width, height).
|
45 |
+
possible_resolutions (list): A list of possible resolutions in the format [(width1, height1), (width2, height2), ...].
|
46 |
+
|
47 |
+
Returns:
|
48 |
+
tuple: The best fit resolution in the format (width, height).
|
49 |
+
"""
|
50 |
+
original_width, original_height = original_size
|
51 |
+
original_aspect_ratio = original_height / original_width
|
52 |
+
original_area = original_width * original_height
|
53 |
+
best_fit = None
|
54 |
+
min_aspect_ratio_diff = float('inf')
|
55 |
+
min_area_ratio = float('inf')
|
56 |
+
|
57 |
+
for width, height in possible_resolutions:
|
58 |
+
aspect_ratio = height / width
|
59 |
+
area = width * height
|
60 |
+
aspect_ratio_diff = max(aspect_ratio, original_aspect_ratio) / min(aspect_ratio, original_aspect_ratio)
|
61 |
+
area_ratio = max(area, original_area) / min(area, original_area)
|
62 |
+
|
63 |
+
if aspect_ratio_diff < min_aspect_ratio_diff or (aspect_ratio_diff == min_aspect_ratio_diff and area_ratio < min_area_ratio):
|
64 |
+
min_aspect_ratio_diff = aspect_ratio_diff
|
65 |
+
min_area_ratio = area_ratio
|
66 |
+
best_fit = (width, height)
|
67 |
+
|
68 |
+
return best_fit
|
69 |
+
|
70 |
+
|
71 |
+
def resize_and_pad_image(image, target_resolution, keep_ratio=False):
|
72 |
+
"""
|
73 |
+
Resize and pad an image to a target resolution
|
74 |
+
|
75 |
+
Args:
|
76 |
+
image (PIL.Image.Image): The input image.
|
77 |
+
target_resolution (tuple): The target resolution (width, height) of the image.
|
78 |
+
|
79 |
+
Returns:
|
80 |
+
PIL.Image.Image: The resized and padded image.
|
81 |
+
"""
|
82 |
+
original_width, original_height = image.size
|
83 |
+
target_width, target_height = target_resolution
|
84 |
+
|
85 |
+
if keep_ratio:
|
86 |
+
# maintaining aspect ratio
|
87 |
+
scale_w = target_width / original_width
|
88 |
+
scale_h = target_height / original_height
|
89 |
+
|
90 |
+
if scale_w < scale_h:
|
91 |
+
new_width = target_width
|
92 |
+
new_height = min(math.ceil(original_height * scale_w), target_height)
|
93 |
+
else:
|
94 |
+
new_height = target_height
|
95 |
+
new_width = min(math.ceil(original_width * scale_h), target_width)
|
96 |
+
|
97 |
+
# Resize the image
|
98 |
+
resized_image = image.resize((new_width, new_height))
|
99 |
+
|
100 |
+
new_image = Image.new('RGB', (target_width, target_height), (0, 0, 0))
|
101 |
+
paste_x = (target_width - new_width) // 2
|
102 |
+
paste_y = (target_height - new_height) // 2
|
103 |
+
new_image.paste(resized_image, (paste_x, paste_y))
|
104 |
+
else:
|
105 |
+
# not maintaining aspect ratio
|
106 |
+
new_image = image.resize((target_width, target_height))
|
107 |
+
|
108 |
+
return new_image
|
109 |
+
|
110 |
+
|
111 |
+
def divide_to_patches(image, patch_size):
|
112 |
+
"""
|
113 |
+
Divides an image into patches of a specified size.
|
114 |
+
|
115 |
+
Args:
|
116 |
+
image (PIL.Image.Image): The input image.
|
117 |
+
patch_size (int): The size of each patch.
|
118 |
+
|
119 |
+
Returns:
|
120 |
+
list: A list of PIL.Image.Image objects representing the patches.
|
121 |
+
"""
|
122 |
+
patches = []
|
123 |
+
width, height = image.size
|
124 |
+
for i in range(0, height, patch_size):
|
125 |
+
for j in range(0, width, patch_size):
|
126 |
+
box = (j, i, j + patch_size, i + patch_size)
|
127 |
+
patch = image.crop(box)
|
128 |
+
patches.append(patch)
|
129 |
+
|
130 |
+
return patches
|
131 |
+
|
132 |
+
|
133 |
+
def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
|
134 |
+
"""
|
135 |
+
Calculate the shape of the image patch grid after the preprocessing for images of any resolution.
|
136 |
+
|
137 |
+
Args:
|
138 |
+
image_size (tuple): The size of the input image in the format (width, height).
|
139 |
+
grid_pinpoints (str): A string representation of a list of possible resolutions.
|
140 |
+
patch_size (int): The size of each image patch.
|
141 |
+
|
142 |
+
Returns:
|
143 |
+
tuple: The shape of the image patch grid in the format (width, height).
|
144 |
+
"""
|
145 |
+
if type(grid_pinpoints) is list:
|
146 |
+
possible_resolutions = grid_pinpoints
|
147 |
+
else:
|
148 |
+
possible_resolutions = ast.literal_eval(grid_pinpoints)
|
149 |
+
width1, height1 = select_best_resolution(image_size, possible_resolutions)
|
150 |
+
width2, height2 = select_best_resolution_v2(image_size, possible_resolutions)
|
151 |
+
if width1*height1 > width2*height2:
|
152 |
+
width, height = width2, height2
|
153 |
+
else:
|
154 |
+
width, height = width1, height1
|
155 |
+
return width // patch_size, height // patch_size
|
156 |
+
|
157 |
+
|
158 |
+
def process_anyres_image(image, image_transform, grid_pinpoints, base_image_size):
|
159 |
+
"""
|
160 |
+
Process an image with variable resolutions.
|
161 |
+
|
162 |
+
Args:
|
163 |
+
image (PIL.Image.Image): The input image to be processed.
|
164 |
+
image_transform: The image processor object.
|
165 |
+
grid_pinpoints (str): A string representation of a list of possible resolutions.
|
166 |
+
|
167 |
+
Returns:
|
168 |
+
torch.Tensor: A tensor containing the processed image patches.
|
169 |
+
"""
|
170 |
+
if type(grid_pinpoints) is list:
|
171 |
+
possible_resolutions = grid_pinpoints
|
172 |
+
else:
|
173 |
+
possible_resolutions = ast.literal_eval(grid_pinpoints)
|
174 |
+
# best_resolution = select_best_resolution(image.size, possible_resolutions)
|
175 |
+
width1, height1 = select_best_resolution(image.size, possible_resolutions)
|
176 |
+
width2, height2 = select_best_resolution_v2(image.size, possible_resolutions)
|
177 |
+
if width1*height1 > width2*height2:
|
178 |
+
width, height = width2, height2
|
179 |
+
else:
|
180 |
+
width, height = width1, height1
|
181 |
+
best_resolution = [width, height]
|
182 |
+
|
183 |
+
image_padded = resize_and_pad_image(image, best_resolution)
|
184 |
+
|
185 |
+
patches = divide_to_patches(image_padded, base_image_size)
|
186 |
+
|
187 |
+
image_original_resize = image.resize((base_image_size, base_image_size))
|
188 |
+
|
189 |
+
image_patches = patches + [image_original_resize] # add the original image as the last patch
|
190 |
+
image_patches = [image_transform(image_patch)
|
191 |
+
for image_patch in image_patches]
|
192 |
+
|
193 |
+
patch_grid = (best_resolution[0]//base_image_size, best_resolution[1]//base_image_size)
|
194 |
+
x_index = (torch.arange(patch_grid[0]).repeat(patch_grid[1], 1) + 0.5)/patch_grid[0]
|
195 |
+
y_index = (torch.arange(patch_grid[1]).unsqueeze(1).repeat(1, patch_grid[0]) + 0.5)/patch_grid[1]
|
196 |
+
patch_pos = torch.stack([x_index, y_index], dim=-1).flatten(0, 1) # h*w, 2
|
197 |
+
|
198 |
+
origin_pos = torch.tensor([[0.5, 0.5]])
|
199 |
+
patch_pos = torch.cat([patch_pos, origin_pos], dim=0) # h*w+1, 2
|
200 |
+
|
201 |
+
return torch.stack(image_patches, dim=0), patch_pos
|
202 |
+
|
203 |
+
|
204 |
+
def load_image_from_base64(image):
|
205 |
+
return Image.open(BytesIO(base64.b64decode(image)))
|
206 |
+
|
207 |
+
|
208 |
+
def anyres_data_collate(batch, tokenizer, dataset_name=None):
|
209 |
+
results = {}
|
210 |
+
keys = batch[0].keys()
|
211 |
+
|
212 |
+
for key in keys:
|
213 |
+
cur = [batch[i][key] for i in range(len(batch)) if batch[i][key] is not None]
|
214 |
+
if len(cur) == 0:
|
215 |
+
results[key] = None
|
216 |
+
elif isinstance(cur[0], torch.Tensor):
|
217 |
+
if key in ['embeds_gen_mask', 'embeds_cmp_mask', 'images', 'images_patch_length', 'patch_position', 'image_size']:
|
218 |
+
results[key] = torch.cat(cur, dim=0)
|
219 |
+
else:
|
220 |
+
if key in ['input_ids']:
|
221 |
+
results[key] = torch.nn.utils.rnn.pad_sequence(cur, batch_first=True, padding_value=tokenizer.pad_token_id)
|
222 |
+
elif key in ['attention_mask']:
|
223 |
+
results[key] = torch.nn.utils.rnn.pad_sequence(cur, batch_first=True, padding_value=0)
|
224 |
+
elif key in ['labels']:
|
225 |
+
results[key] = torch.nn.utils.rnn.pad_sequence(cur, batch_first=True, padding_value=-100)
|
226 |
+
elif key in ['ids_gen_mask', 'ids_cmp_mask']:
|
227 |
+
results[key] = torch.nn.utils.rnn.pad_sequence(cur, batch_first=True, padding_value=False)
|
228 |
+
|
229 |
+
else:
|
230 |
+
results[key] = torch.stack(cur, dim=0)
|
231 |
+
else:
|
232 |
+
results[key] = cur
|
233 |
+
|
234 |
+
results['dataset_name'] = dataset_name
|
235 |
+
|
236 |
+
return results
|
237 |
+
|
238 |
+
|
239 |
+
def anyres_data_collate_old(batch, dataset_name=None):
|
240 |
+
results = {}
|
241 |
+
keys = batch[0].keys()
|
242 |
+
|
243 |
+
for key in keys:
|
244 |
+
cur = [batch[i][key] for i in range(len(batch)) if batch[i][key] is not None]
|
245 |
+
if len(cur) == 0:
|
246 |
+
results[key] = None
|
247 |
+
elif isinstance(cur[0], torch.Tensor):
|
248 |
+
if key in ['embeds_gen_mask', 'embeds_cmp_mask', 'images', 'images_patch_length', 'patch_position', 'image_size']:
|
249 |
+
results[key] = torch.cat(cur, dim=0)
|
250 |
+
else:
|
251 |
+
results[key] = torch.stack(cur, dim=0)
|
252 |
+
else:
|
253 |
+
results[key] = cur
|
254 |
+
|
255 |
+
results['dataset_name'] = dataset_name
|
256 |
+
|
257 |
+
return results
|
src/inference/eval_img2edit_seed_x.py
ADDED
@@ -0,0 +1,155 @@
1 |
+
import hydra
|
2 |
+
import torch
|
3 |
+
import os
|
4 |
+
import re
|
5 |
+
import pyrootutils
|
6 |
+
from PIL import Image
|
7 |
+
from omegaconf import OmegaConf
|
8 |
+
from diffusers import AutoencoderKL, UNet2DConditionModel, EulerDiscreteScheduler, Transformer2DModel
|
9 |
+
from any_res import process_anyres_image
|
10 |
+
|
11 |
+
pyrootutils.setup_root(__file__, indicator='.project-root', pythonpath=True)
|
12 |
+
|
13 |
+
BOI_TOKEN = '<img>'
|
14 |
+
BOP_TOKEN = '<patch>'
|
15 |
+
EOI_TOKEN = '</img>'
|
16 |
+
EOP_TOKEN = '</patch>'
|
17 |
+
IMG_TOKEN = '<img_{:05d}>'
|
18 |
+
|
19 |
+
resolution_grids = ['1x1']
|
20 |
+
base_resolution = 448
|
21 |
+
|
22 |
+
device = 'cuda:0'
|
23 |
+
device1 = 'cuda:1'
|
24 |
+
dtype = torch.float16
|
25 |
+
dtype_str = 'fp16'
|
26 |
+
num_img_in_tokens = 64
|
27 |
+
num_img_out_tokens = 64
|
28 |
+
instruction_prompt = '[INST] {instruction} [/INST]\n'
|
29 |
+
|
30 |
+
save_dir = 'vis'
|
31 |
+
os.makedirs(save_dir, exist_ok=True)
|
32 |
+
|
33 |
+
tokenizer_cfg_path = 'configs/tokenizer/clm_llama_tokenizer_224loc_anyres.yaml'
|
34 |
+
image_transform_cfg_path = 'configs/processer/qwen_448_transform.yaml'
|
35 |
+
visual_encoder_cfg_path = 'configs/visual_encoder/qwen_vitg_448.yaml'
|
36 |
+
llm_cfg_path = 'configs/clm_models/llm_seed_x_edit.yaml'
|
37 |
+
agent_cfg_path = 'configs/clm_models/agent_seed_x_edit.yaml'
|
38 |
+
adapter_cfg_path = 'configs/sdxl_adapter/sdxl_qwen_vit_resampler_l4_q64_full_with_latent_image_pretrain_no_normalize.yaml'
|
39 |
+
discrete_model_cfg_path = 'configs/discrete_model/discrete_identity.yaml'
|
40 |
+
|
41 |
+
diffusion_model_path = 'pretrained/stable-diffusion-xl-base-1.0'
|
42 |
+
|
43 |
+
tokenizer_cfg = OmegaConf.load(tokenizer_cfg_path)
|
44 |
+
tokenizer = hydra.utils.instantiate(tokenizer_cfg)
|
45 |
+
|
46 |
+
image_transform_cfg = OmegaConf.load(image_transform_cfg_path)
|
47 |
+
image_transform = hydra.utils.instantiate(image_transform_cfg)
|
48 |
+
|
49 |
+
visual_encoder_cfg = OmegaConf.load(visual_encoder_cfg_path)
|
50 |
+
visual_encoder = hydra.utils.instantiate(visual_encoder_cfg)
|
51 |
+
visual_encoder.eval().to(device1, dtype=dtype)
|
52 |
+
print('Init visual encoder done')
|
53 |
+
|
54 |
+
llm_cfg = OmegaConf.load(llm_cfg_path)
|
55 |
+
llm = hydra.utils.instantiate(llm_cfg, torch_dtype=dtype)
|
56 |
+
print('Init llm done.')
|
57 |
+
|
58 |
+
agent_model_cfg = OmegaConf.load(agent_cfg_path)
|
59 |
+
agent_model = hydra.utils.instantiate(agent_model_cfg, llm=llm)
|
60 |
+
|
61 |
+
agent_model.eval().to(device, dtype=dtype)
|
62 |
+
print('Init agent model done')
|
63 |
+
|
64 |
+
noise_scheduler = EulerDiscreteScheduler.from_pretrained(diffusion_model_path, subfolder="scheduler")
|
65 |
+
print('init vae')
|
66 |
+
vae = AutoencoderKL.from_pretrained(diffusion_model_path, subfolder="vae").to(device1, dtype=dtype)
|
67 |
+
print('init unet')
|
68 |
+
unet = UNet2DConditionModel.from_pretrained(diffusion_model_path, subfolder="unet").to(device1, dtype=dtype)
|
69 |
+
|
70 |
+
adapter_cfg = OmegaConf.load(adapter_cfg_path)
|
71 |
+
adapter = hydra.utils.instantiate(adapter_cfg, unet=unet).to(device1, dtype=dtype).eval()
|
72 |
+
|
73 |
+
discrete_model_cfg = OmegaConf.load(discrete_model_cfg_path)
|
74 |
+
discrete_model = hydra.utils.instantiate(discrete_model_cfg).to(device1).eval()
|
75 |
+
print('Init adapter done')
|
76 |
+
|
77 |
+
adapter.init_pipe(vae=vae,
|
78 |
+
scheduler=noise_scheduler,
|
79 |
+
visual_encoder=visual_encoder,
|
80 |
+
image_transform=image_transform,
|
81 |
+
dtype=dtype,
|
82 |
+
device=device1)
|
83 |
+
|
84 |
+
print('Init adapter pipe done')
|
85 |
+
boi_token_id = tokenizer.encode(BOI_TOKEN, add_special_tokens=False)[0]
|
86 |
+
eoi_token_id = tokenizer.encode(EOI_TOKEN, add_special_tokens=False)[0]
|
87 |
+
|
88 |
+
bop_token_id = tokenizer.encode(BOP_TOKEN, add_special_tokens=False)[0]
|
89 |
+
eop_token_id = tokenizer.encode(EOP_TOKEN, add_special_tokens=False)[0]
|
90 |
+
|
91 |
+
grid_pinpoints = []
|
92 |
+
for scale in resolution_grids:
|
93 |
+
s1, s2 = scale.split('x')
|
94 |
+
grid_pinpoints.append([int(s1)*base_resolution, int(s2)*base_resolution])
|
95 |
+
grid_pinpoints = grid_pinpoints
|
96 |
+
|
97 |
+
|
98 |
+
image_path = 'demo_images/car.jpg'
|
99 |
+
instruction = 'Make it under the sunset'
|
100 |
+
|
101 |
+
image = Image.open(image_path).convert('RGB')
|
102 |
+
source_image = image.resize((1024, 1024))
|
103 |
+
|
104 |
+
image_tensor, patch_pos_tensor = process_anyres_image(image, image_transform, grid_pinpoints, base_resolution)
|
105 |
+
embeds_cmp_mask = torch.tensor([True]*image_tensor.shape[0]).to(device, dtype=torch.bool)
|
106 |
+
|
107 |
+
patch_pos = [patch_pos_tensor]
|
108 |
+
patch_position = torch.cat(patch_pos, dim=0)
|
109 |
+
|
110 |
+
image_tensor = image_tensor.to(device1, dtype=dtype)
|
111 |
+
|
112 |
+
patch_length = image_tensor.shape[0]
|
113 |
+
image_tokens = ''
|
114 |
+
for _ in range(patch_length-1):
|
115 |
+
image_tokens += BOP_TOKEN + ''.join(IMG_TOKEN.format(int(item)) for item in range(num_img_in_tokens)) + EOP_TOKEN
|
116 |
+
image_tokens += BOI_TOKEN + ''.join(IMG_TOKEN.format(int(item)) for item in range(num_img_in_tokens)) + EOI_TOKEN
|
117 |
+
|
118 |
+
prompt = instruction_prompt.format_map({'instruction': image_tokens + instruction})
|
119 |
+
|
120 |
+
input_ids = tokenizer.encode(prompt, add_special_tokens=False)
|
121 |
+
input_ids = [tokenizer.bos_token_id] + input_ids
|
122 |
+
|
123 |
+
input_ids = torch.tensor(input_ids).to(device, dtype=torch.long)
|
124 |
+
|
125 |
+
ids_cmp_mask = torch.zeros_like(input_ids, dtype=torch.bool)
|
126 |
+
|
127 |
+
boi_indices = torch.where(torch.logical_or(input_ids == boi_token_id, input_ids == bop_token_id))[0].tolist()
|
128 |
+
eoi_indices = torch.where(torch.logical_or(input_ids == eoi_token_id, input_ids == eop_token_id))[0].tolist()
|
129 |
+
|
130 |
+
for boi_idx, eoi_idx in zip(boi_indices, eoi_indices):
|
131 |
+
ids_cmp_mask[boi_idx + 1:eoi_idx] = True
|
132 |
+
|
133 |
+
input_ids = input_ids.unsqueeze(0)
|
134 |
+
ids_cmp_mask = ids_cmp_mask.unsqueeze(0)
|
135 |
+
|
136 |
+
with torch.no_grad():
|
137 |
+
image_embeds = visual_encoder(image_tensor)
|
138 |
+
image_embeds = image_embeds.to(device)
|
139 |
+
output = agent_model.generate(tokenizer=tokenizer,
|
140 |
+
input_ids=input_ids,
|
141 |
+
image_embeds=image_embeds,
|
142 |
+
embeds_cmp_mask=embeds_cmp_mask,
|
143 |
+
patch_positions=patch_position,
|
144 |
+
ids_cmp_mask=ids_cmp_mask,
|
145 |
+
max_new_tokens=512,
|
146 |
+
num_img_gen_tokens=num_img_out_tokens)
|
147 |
+
text = re.sub('<[^>]*>', '', output['text'])
|
148 |
+
print(text)
|
149 |
+
|
150 |
+
if output['has_img_output']:
|
151 |
+
images = adapter.generate(image_embeds=output['img_gen_feat'].to(device1), latent_image=source_image, num_inference_steps=50)
|
152 |
+
|
153 |
+
save_path = os.path.join(save_dir, str(len(os.listdir(save_dir))) + '_' + instruction + '.jpg')
|
154 |
+
images[0].save(save_path)
|
155 |
+
torch.cuda.empty_cache()
|
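
Note (not part of the commit): the any-resolution prompt construction repeated in these scripts can be restated as a small self-contained helper. This is a hedged sketch only; the constants mirror those defined in the script above, and the helper name build_anyres_image_tokens is hypothetical.

# Hedged sketch: one <patch>...</patch> placeholder group per leading patch entry,
# plus a final <img>...</img> group, matching the loop in the script above
# (patch_length corresponds to image_tensor.shape[0]).
BOI_TOKEN, EOI_TOKEN = '<img>', '</img>'
BOP_TOKEN, EOP_TOKEN = '<patch>', '</patch>'
IMG_TOKEN = '<img_{:05d}>'


def build_anyres_image_tokens(patch_length: int, num_img_in_tokens: int = 64) -> str:
    placeholder = ''.join(IMG_TOKEN.format(i) for i in range(num_img_in_tokens))
    groups = [BOP_TOKEN + placeholder + EOP_TOKEN for _ in range(patch_length - 1)]
    groups.append(BOI_TOKEN + placeholder + EOI_TOKEN)
    return ''.join(groups)


if __name__ == '__main__':
    tokens = build_anyres_image_tokens(patch_length=3)
    print(tokens.count(BOP_TOKEN), tokens.count(BOI_TOKEN))  # 2 1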
src/inference/eval_img2text_seed_x.py
ADDED
@@ -0,0 +1,235 @@
import hydra
import torch
import os
import pyrootutils
from PIL import Image
import re
import cv2
import numpy as np
from omegaconf import OmegaConf
from diffusers import AutoencoderKL, UNet2DConditionModel, EulerDiscreteScheduler
from any_res import process_anyres_image


pyrootutils.setup_root(__file__, indicator='.project-root', pythonpath=True)


def visualize_bbox(image, bboxes, save_path):
    img_width, img_height = image.size
    image = np.array(image)
    image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
    for bbox in bboxes:
        x_center, y_center, box_width, box_height = bbox

        x_center = x_center / 224 * img_width
        y_center = y_center / 224 * img_height

        box_width = box_width / 224 * img_width
        box_height = box_height / 224 * img_height

        x1 = int(x_center - box_width / 2)
        y1 = int(y_center - box_height / 2)
        x2 = int(x_center + box_width / 2)
        y2 = int(y_center + box_height / 2)

        cv2.rectangle(image, (x1, y1), (x2, y2), (0, 255, 0), 2)

    cv2.imwrite(save_path, image)


def extract_box(output_str):
    boxes = re.findall('<box_start>(.*?)<box_end>', output_str)
    if len(boxes) > 0:
        bboxes = [[int(num) for num in re.findall(r'<loc-(\d+)>', box)] for box in boxes]
    else:
        bboxes = None

    return bboxes


BOI_TOKEN = '<img>'
BOP_TOKEN = '<patch>'
EOI_TOKEN = '</img>'
EOP_TOKEN = '</patch>'
IMG_TOKEN = '<img_{:05d}>'

instruction_prompt = '[INST] {instruction} [/INST]\n'

resolution_grids = ['1x1', '1x2', '1x3', '2x1', '3x1', '1x4', '4x1', '2x2']
base_resolution = 448

device = 'cuda:0'
device1 = 'cuda:1'
dtype = torch.float16
dtype_str = 'fp16'
num_img_in_tokens = 64
num_img_out_tokens = 64

tokenizer_cfg_path = 'configs/tokenizer/clm_llama_tokenizer_224loc_anyres.yaml'
image_transform_cfg_path = 'configs/processer/qwen_448_transform.yaml'
visual_encoder_cfg_path = 'configs/visual_encoder/qwen_vitg_448.yaml'
llm_cfg_path = 'configs/clm_models/llm_seed_x_i.yaml'
agent_cfg_path = 'configs/clm_models/agent_seed_x_i.yaml'
adapter_cfg_path = 'configs/sdxl_adapter/sdxl_qwen_vit_resampler_l4_q64_pretrain_no_normalize.yaml'
discrete_model_cfg_path = 'configs/discrete_model/discrete_identity.yaml'

diffusion_model_path = 'pretrained/stable-diffusion-xl-base-1.0'

tokenizer_cfg = OmegaConf.load(tokenizer_cfg_path)
tokenizer = hydra.utils.instantiate(tokenizer_cfg)

image_transform_cfg = OmegaConf.load(image_transform_cfg_path)
image_transform = hydra.utils.instantiate(image_transform_cfg)

visual_encoder_cfg = OmegaConf.load(visual_encoder_cfg_path)
visual_encoder = hydra.utils.instantiate(visual_encoder_cfg)
visual_encoder.eval().to(device1, dtype=dtype)
print('Init visual encoder done')

llm_cfg = OmegaConf.load(llm_cfg_path)
llm = hydra.utils.instantiate(llm_cfg, torch_dtype=dtype)
print('Init llm done.')

agent_model_cfg = OmegaConf.load(agent_cfg_path)
agent_model = hydra.utils.instantiate(agent_model_cfg, llm=llm)

agent_model.eval().to(device, dtype=dtype)
print('Init agent model done')

noise_scheduler = EulerDiscreteScheduler.from_pretrained(diffusion_model_path, subfolder="scheduler")
print('Init vae')
vae = AutoencoderKL.from_pretrained(diffusion_model_path, subfolder="vae").to(device1, dtype=dtype)
print('Init unet')
unet = UNet2DConditionModel.from_pretrained(diffusion_model_path, subfolder="unet").to(device1, dtype=dtype)

adapter_cfg = OmegaConf.load(adapter_cfg_path)
adapter = hydra.utils.instantiate(adapter_cfg, unet=unet).to(device1, dtype=dtype).eval()

discrete_model_cfg = OmegaConf.load(discrete_model_cfg_path)
discrete_model = hydra.utils.instantiate(discrete_model_cfg).to(device1).eval()
print('Init adapter done')

adapter.init_pipe(vae=vae,
                  scheduler=noise_scheduler,
                  visual_encoder=visual_encoder,
                  image_transform=image_transform,
                  discrete_model=discrete_model,
                  dtype=dtype,
                  device=device1)

print('Init adapter pipe done')
boi_token_id = tokenizer.encode(BOI_TOKEN, add_special_tokens=False)[0]
eoi_token_id = tokenizer.encode(EOI_TOKEN, add_special_tokens=False)[0]

bop_token_id = tokenizer.encode(BOP_TOKEN, add_special_tokens=False)[0]
eop_token_id = tokenizer.encode(EOP_TOKEN, add_special_tokens=False)[0]

grid_pinpoints = []
for scale in resolution_grids:
    s1, s2 = scale.split('x')
    grid_pinpoints.append([int(s1) * base_resolution, int(s2) * base_resolution])

# image comprehension
image_path = 'demo_images/advisor.png'
image = Image.open(image_path).convert('RGB')
image_tensor, patch_pos_tensor = process_anyres_image(image, image_transform, grid_pinpoints, base_resolution)
embeds_cmp_mask = torch.tensor([True] * image_tensor.shape[0]).to(device, dtype=torch.bool)

patch_pos = [patch_pos_tensor]
patch_position = torch.cat(patch_pos, dim=0)

image_tensor = image_tensor.to(device1, dtype=dtype)

patch_length = image_tensor.shape[0]
image_tokens = ''
for _ in range(patch_length - 1):
    image_tokens += BOP_TOKEN + ''.join(IMG_TOKEN.format(int(item)) for item in range(num_img_in_tokens)) + EOP_TOKEN
image_tokens += BOI_TOKEN + ''.join(IMG_TOKEN.format(int(item)) for item in range(num_img_in_tokens)) + EOI_TOKEN

question = 'Can I connect with an advisor on Sunday?'
prompt = instruction_prompt.format_map({'instruction': image_tokens + question})

input_ids = tokenizer.encode(prompt, add_special_tokens=False)
input_ids = [tokenizer.bos_token_id] + input_ids

input_ids = torch.tensor(input_ids).to(device, dtype=torch.long)

ids_cmp_mask = torch.zeros_like(input_ids, dtype=torch.bool)

boi_indices = torch.where(torch.logical_or(input_ids == boi_token_id, input_ids == bop_token_id))[0].tolist()
eoi_indices = torch.where(torch.logical_or(input_ids == eoi_token_id, input_ids == eop_token_id))[0].tolist()

for boi_idx, eoi_idx in zip(boi_indices, eoi_indices):
    ids_cmp_mask[boi_idx + 1:eoi_idx] = True

input_ids = input_ids.unsqueeze(0)
ids_cmp_mask = ids_cmp_mask.unsqueeze(0)

with torch.no_grad():
    image_embeds = visual_encoder(image_tensor)
    image_embeds = image_embeds.to(device)
    output = agent_model.generate(tokenizer=tokenizer,
                                  input_ids=input_ids,
                                  image_embeds=image_embeds,
                                  embeds_cmp_mask=embeds_cmp_mask,
                                  patch_positions=patch_position,
                                  ids_cmp_mask=ids_cmp_mask,
                                  max_new_tokens=512,
                                  num_img_gen_tokens=num_img_out_tokens)

text = re.sub('<[^>]*>', '', output['text'])
print(text)

# detection
image_path = 'demo_images/ground.png'
image = Image.open(image_path).convert('RGB')
image_tensor, patch_pos_tensor = process_anyres_image(image, image_transform, grid_pinpoints, base_resolution)
embeds_cmp_mask = torch.tensor([True] * image_tensor.shape[0]).to(device, dtype=torch.bool)

patch_pos = [patch_pos_tensor]
patch_position = torch.cat(patch_pos, dim=0)

image_tensor = image_tensor.to(device1, dtype=dtype)

patch_length = image_tensor.shape[0]
image_tokens = ''
for _ in range(patch_length - 1):
    image_tokens += BOP_TOKEN + ''.join(IMG_TOKEN.format(int(item)) for item in range(num_img_in_tokens)) + EOP_TOKEN
image_tokens += BOI_TOKEN + ''.join(IMG_TOKEN.format(int(item)) for item in range(num_img_in_tokens)) + EOI_TOKEN

question = 'Is there anything in the image that can protect me from catching the flu virus when I go out? Show me the location.'
prompt = instruction_prompt.format_map({'instruction': image_tokens + question})

input_ids = tokenizer.encode(prompt, add_special_tokens=False)
input_ids = [tokenizer.bos_token_id] + input_ids

input_ids = torch.tensor(input_ids).to(device, dtype=torch.long)

ids_cmp_mask = torch.zeros_like(input_ids, dtype=torch.bool)

boi_indices = torch.where(torch.logical_or(input_ids == boi_token_id, input_ids == bop_token_id))[0].tolist()
eoi_indices = torch.where(torch.logical_or(input_ids == eoi_token_id, input_ids == eop_token_id))[0].tolist()

for boi_idx, eoi_idx in zip(boi_indices, eoi_indices):
    ids_cmp_mask[boi_idx + 1:eoi_idx] = True

input_ids = input_ids.unsqueeze(0)
ids_cmp_mask = ids_cmp_mask.unsqueeze(0)

with torch.no_grad():
    image_embeds = visual_encoder(image_tensor)
    image_embeds = image_embeds.to(device)
    output = agent_model.generate(tokenizer=tokenizer,
                                  input_ids=input_ids,
                                  image_embeds=image_embeds,
                                  embeds_cmp_mask=embeds_cmp_mask,
                                  patch_positions=patch_position,
                                  ids_cmp_mask=ids_cmp_mask,
                                  max_new_tokens=512,
                                  num_img_gen_tokens=num_img_out_tokens)

print(output['text'])
bbox = extract_box(output['text'])
if bbox is not None:
    save_path = 'vis/ground.png'
    visualize_bbox(image, bbox, save_path)
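
Note (not part of the commit): the grounding output format used above can be checked without running the model. Each <box_start>...<box_end> span carries <loc-XXX> values normalized to a 224 grid in (x_center, y_center, width, height) order, which visualize_bbox rescales to pixel space. The model-free sketch below restates that post-processing; the sample output string and the 640x480 image size are invented for illustration.

# Hedged, model-free sketch of the box post-processing in eval_img2text_seed_x.py.
# The sample string and image size below are made up.
import re

sample_output = 'A face mask is on the table. <box_start><loc-112><loc-112><loc-56><loc-84><box_end>'

boxes = re.findall('<box_start>(.*?)<box_end>', sample_output)
bboxes = [[int(n) for n in re.findall(r'<loc-(\d+)>', b)] for b in boxes]

img_width, img_height = 640, 480
for x_c, y_c, w, h in bboxes:
    # rescale from the 224-normalized grid to pixel coordinates, as visualize_bbox does
    x_c, w = x_c / 224 * img_width, w / 224 * img_width
    y_c, h = y_c / 224 * img_height, h / 224 * img_height
    x1, y1 = int(x_c - w / 2), int(y_c - h / 2)
    x2, y2 = int(x_c + w / 2), int(y_c + h / 2)
    print((x1, y1, x2, y2))  # corner coordinates ready for cv2.rectangle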
src/inference/eval_text2img_seed_x.py
ADDED
@@ -0,0 +1,94 @@
import hydra
import torch
import os
import pyrootutils
from PIL import Image
from omegaconf import OmegaConf
from diffusers import AutoencoderKL, UNet2DConditionModel, EulerDiscreteScheduler


pyrootutils.setup_root(__file__, indicator='.project-root', pythonpath=True)

BOI_TOKEN = '<img>'
EOI_TOKEN = '</img>'
IMG_TOKEN = '<img_{:05d}>'

device = 'cuda:0'
device_2 = 'cuda:1'
dtype = torch.float16
dtype_str = 'fp16'
num_img_in_tokens = 64
num_img_out_tokens = 64

instruction_prompt = '[INST] Generate an image: {caption} [/INST]\n'

tokenizer_cfg_path = 'configs/tokenizer/clm_llama_tokenizer_224loc_anyres.yaml'
image_transform_cfg_path = 'configs/processer/qwen_448_transform.yaml'
visual_encoder_cfg_path = 'configs/visual_encoder/qwen_vitg_448.yaml'
llm_cfg_path = 'configs/clm_models/llm_seed_x_i.yaml'
agent_cfg_path = 'configs/clm_models/agent_seed_x_i.yaml'
adapter_cfg_path = 'configs/sdxl_adapter/sdxl_qwen_vit_resampler_l4_q64_pretrain_no_normalize.yaml'
discrete_model_cfg_path = 'configs/discrete_model/discrete_identity.yaml'

diffusion_model_path = 'pretrained/stable-diffusion-xl-base-1.0'

save_dir = 'vis'
os.makedirs(save_dir, exist_ok=True)

tokenizer_cfg = OmegaConf.load(tokenizer_cfg_path)
tokenizer = hydra.utils.instantiate(tokenizer_cfg)

image_transform_cfg = OmegaConf.load(image_transform_cfg_path)
image_transform = hydra.utils.instantiate(image_transform_cfg)

visual_encoder_cfg = OmegaConf.load(visual_encoder_cfg_path)
visual_encoder = hydra.utils.instantiate(visual_encoder_cfg)
visual_encoder.eval().to(device_2, dtype=dtype)
print('Init visual encoder done')

llm_cfg = OmegaConf.load(llm_cfg_path)
llm = hydra.utils.instantiate(llm_cfg, torch_dtype=dtype)
print('Init llm done.')

agent_model_cfg = OmegaConf.load(agent_cfg_path)
agent_model = hydra.utils.instantiate(agent_model_cfg, llm=llm)

agent_model.eval().to(device, dtype=dtype)
print('Init agent model done')

noise_scheduler = EulerDiscreteScheduler.from_pretrained(diffusion_model_path, subfolder="scheduler")
print('Init vae')
vae = AutoencoderKL.from_pretrained(diffusion_model_path, subfolder="vae").to(device_2, dtype=dtype)
print('Init unet')
unet = UNet2DConditionModel.from_pretrained(diffusion_model_path, subfolder="unet").to(device_2, dtype=dtype)

adapter_cfg = OmegaConf.load(adapter_cfg_path)
adapter = hydra.utils.instantiate(adapter_cfg, unet=unet).to(device_2, dtype=dtype).eval()

discrete_model_cfg = OmegaConf.load(discrete_model_cfg_path)
discrete_model = hydra.utils.instantiate(discrete_model_cfg).to(device_2).eval()
print('Init adapter done')

adapter.init_pipe(vae=vae,
                  scheduler=noise_scheduler,
                  visual_encoder=visual_encoder,
                  image_transform=image_transform,
                  discrete_model=discrete_model,
                  dtype=dtype,
                  device=device_2)

print('Init adapter pipe done')

caption = 'A cybernetic soldier, enhanced with advanced weapons systems and tactical analysis software, on a mission behind enemy lines.'
prompt = instruction_prompt.format_map({'caption': caption})
prompt_ids = tokenizer.encode(prompt, add_special_tokens=False)
input_ids = torch.tensor([tokenizer.bos_token_id] + prompt_ids).to(device, dtype=torch.long).unsqueeze(0)
output = agent_model.generate(tokenizer=tokenizer, input_ids=input_ids, num_img_gen_tokens=num_img_out_tokens)
print(output['has_img_output'])
print(output['text'])

if output['has_img_output']:
    images = adapter.generate(image_embeds=output['img_gen_feat'].to(device_2), num_inference_steps=50)
    save_path = os.path.join(save_dir, caption.replace('.', '') + '.png')
    images[0].save(save_path)

torch.cuda.empty_cache()
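
Note (not part of the commit): to render several captions in one run, the prompt -> agent_model.generate -> adapter.generate flow above can be wrapped in a loop. This is a hedged sketch that assumes the tokenizer, agent_model, adapter, instruction_prompt, num_img_out_tokens, device, device_2 and save_dir initialized by the script above are already in scope; the caption list and the _slug helper are illustrative.

# Hedged sketch: batch several captions through the pipeline set up above.
import os
import re
import torch

captions = [
    'A watercolor painting of a lighthouse at dawn.',
    'A low-poly render of a fox in a snowy forest.',
]


def _slug(text, max_len=80):
    # filesystem-friendly file name derived from the caption (illustrative helper)
    return re.sub(r'[^a-zA-Z0-9]+', '_', text).strip('_')[:max_len]


for caption in captions:
    prompt = instruction_prompt.format_map({'caption': caption})
    prompt_ids = tokenizer.encode(prompt, add_special_tokens=False)
    input_ids = torch.tensor([tokenizer.bos_token_id] + prompt_ids).to(device, dtype=torch.long).unsqueeze(0)
    output = agent_model.generate(tokenizer=tokenizer, input_ids=input_ids, num_img_gen_tokens=num_img_out_tokens)
    if output['has_img_output']:
        images = adapter.generate(image_embeds=output['img_gen_feat'].to(device_2), num_inference_steps=50)
        images[0].save(os.path.join(save_dir, _slug(caption) + '.png'))
    torch.cuda.empty_cache()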
src/models/detokenizer/__init__.py
ADDED
@@ -0,0 +1 @@
src/models/detokenizer/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (182 Bytes).
src/models/detokenizer/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (175 Bytes).
src/models/detokenizer/__pycache__/adapter_modules.cpython-311.pyc
ADDED
Binary file (14 kB).
src/models/detokenizer/__pycache__/adapter_modules.cpython-38.pyc
ADDED
Binary file (7.31 kB).
src/models/detokenizer/__pycache__/attention_processor.cpython-38.pyc
ADDED
Binary file (7.4 kB).
src/models/detokenizer/__pycache__/ipa_utils.cpython-38.pyc
ADDED
Binary file (397 Bytes).
src/models/detokenizer/__pycache__/pipeline_stable_diffusion_t2i_edit.cpython-38.pyc
ADDED
Binary file (28.3 kB).
src/models/detokenizer/__pycache__/pipeline_stable_diffusion_xl_t2i_edit.cpython-311.pyc
ADDED
Binary file (53 kB).
src/models/detokenizer/__pycache__/pipeline_stable_diffusion_xl_t2i_edit.cpython-38.pyc
ADDED
Binary file (36.8 kB).
src/models/detokenizer/__pycache__/resampler.cpython-311.pyc
ADDED
Binary file (16.2 kB).