ohashi56225
commited on
Commit
•
20150e3
1
Parent(s):
033911f
Upload LlavaForConditionalGeneration
Browse files- config.json +242 -0
- configuration_llava.py +131 -0
- generation_config.json +6 -0
- model.safetensors +3 -0
- modeling_llava.py +345 -0
config.json
ADDED
@@ -0,0 +1,242 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"_name_or_path": "output/jp-llava-small-stair-bs128-lr5e5/checkpoints",
|
3 |
+
"architectures": [
|
4 |
+
"LlavaForConditionalGeneration"
|
5 |
+
],
|
6 |
+
"auto_map": {
|
7 |
+
"AutoConfig": "configuration_llava.LlavaConfig",
|
8 |
+
"AutoModelForVision2Seq": "modeling_llava.LlavaForConditionalGeneration"
|
9 |
+
},
|
10 |
+
"initializer_factor": 1.0,
|
11 |
+
"initializer_range": 0.02,
|
12 |
+
"mlp_config": {
|
13 |
+
"_name_or_path": "",
|
14 |
+
"add_cross_attention": false,
|
15 |
+
"architectures": null,
|
16 |
+
"bad_words_ids": null,
|
17 |
+
"begin_suppress_tokens": null,
|
18 |
+
"bos_token_id": null,
|
19 |
+
"chunk_size_feed_forward": 0,
|
20 |
+
"cross_attention_hidden_size": null,
|
21 |
+
"decoder_start_token_id": null,
|
22 |
+
"diversity_penalty": 0.0,
|
23 |
+
"do_sample": false,
|
24 |
+
"early_stopping": false,
|
25 |
+
"encoder_no_repeat_ngram_size": 0,
|
26 |
+
"eos_token_id": null,
|
27 |
+
"exponential_decay_length_penalty": null,
|
28 |
+
"finetuning_task": null,
|
29 |
+
"forced_bos_token_id": null,
|
30 |
+
"forced_eos_token_id": null,
|
31 |
+
"id2label": {
|
32 |
+
"0": "LABEL_0",
|
33 |
+
"1": "LABEL_1"
|
34 |
+
},
|
35 |
+
"is_decoder": false,
|
36 |
+
"is_encoder_decoder": false,
|
37 |
+
"label2id": {
|
38 |
+
"LABEL_0": 0,
|
39 |
+
"LABEL_1": 1
|
40 |
+
},
|
41 |
+
"length_penalty": 1.0,
|
42 |
+
"max_length": 20,
|
43 |
+
"min_length": 0,
|
44 |
+
"model_type": "llava_mlp",
|
45 |
+
"no_repeat_ngram_size": 0,
|
46 |
+
"num_beam_groups": 1,
|
47 |
+
"num_beams": 1,
|
48 |
+
"num_hidden_layers": 2,
|
49 |
+
"num_return_sequences": 1,
|
50 |
+
"output_attentions": false,
|
51 |
+
"output_hidden_states": false,
|
52 |
+
"output_scores": false,
|
53 |
+
"pad_token_id": null,
|
54 |
+
"prefix": null,
|
55 |
+
"problem_type": null,
|
56 |
+
"pruned_heads": {},
|
57 |
+
"remove_invalid_values": false,
|
58 |
+
"repetition_penalty": 1.0,
|
59 |
+
"return_dict": true,
|
60 |
+
"return_dict_in_generate": false,
|
61 |
+
"sep_token_id": null,
|
62 |
+
"suppress_tokens": null,
|
63 |
+
"task_specific_params": null,
|
64 |
+
"temperature": 1.0,
|
65 |
+
"tf_legacy_loss": false,
|
66 |
+
"tie_encoder_decoder": false,
|
67 |
+
"tie_word_embeddings": true,
|
68 |
+
"tokenizer_class": null,
|
69 |
+
"top_k": 50,
|
70 |
+
"top_p": 1.0,
|
71 |
+
"torch_dtype": null,
|
72 |
+
"torchscript": false,
|
73 |
+
"typical_p": 1.0,
|
74 |
+
"use_bfloat16": false
|
75 |
+
},
|
76 |
+
"model_type": "llava",
|
77 |
+
"text_config": {
|
78 |
+
"_name_or_path": "rinna/japanese-gpt-neox-small",
|
79 |
+
"add_cross_attention": false,
|
80 |
+
"architectures": [
|
81 |
+
"GPTNeoXForCausalLM"
|
82 |
+
],
|
83 |
+
"attention_dropout": 0.0,
|
84 |
+
"bad_words_ids": null,
|
85 |
+
"begin_suppress_tokens": null,
|
86 |
+
"bos_token_id": 2,
|
87 |
+
"chunk_size_feed_forward": 0,
|
88 |
+
"classifier_dropout": 0.1,
|
89 |
+
"cross_attention_hidden_size": null,
|
90 |
+
"decoder_start_token_id": null,
|
91 |
+
"diversity_penalty": 0.0,
|
92 |
+
"do_sample": false,
|
93 |
+
"early_stopping": false,
|
94 |
+
"encoder_no_repeat_ngram_size": 0,
|
95 |
+
"eos_token_id": 3,
|
96 |
+
"exponential_decay_length_penalty": null,
|
97 |
+
"finetuning_task": null,
|
98 |
+
"forced_bos_token_id": null,
|
99 |
+
"forced_eos_token_id": null,
|
100 |
+
"hidden_act": "gelu",
|
101 |
+
"hidden_dropout": 0.0,
|
102 |
+
"hidden_size": 768,
|
103 |
+
"id2label": {
|
104 |
+
"0": "LABEL_0",
|
105 |
+
"1": "LABEL_1"
|
106 |
+
},
|
107 |
+
"initializer_range": 0.02,
|
108 |
+
"intermediate_size": 3072,
|
109 |
+
"is_decoder": false,
|
110 |
+
"is_encoder_decoder": false,
|
111 |
+
"label2id": {
|
112 |
+
"LABEL_0": 0,
|
113 |
+
"LABEL_1": 1
|
114 |
+
},
|
115 |
+
"layer_norm_eps": 1e-05,
|
116 |
+
"length_penalty": 1.0,
|
117 |
+
"max_length": 20,
|
118 |
+
"max_position_embeddings": 2048,
|
119 |
+
"min_length": 0,
|
120 |
+
"model_type": "gpt_neox",
|
121 |
+
"no_repeat_ngram_size": 0,
|
122 |
+
"num_attention_heads": 12,
|
123 |
+
"num_beam_groups": 1,
|
124 |
+
"num_beams": 1,
|
125 |
+
"num_hidden_layers": 12,
|
126 |
+
"num_return_sequences": 1,
|
127 |
+
"output_attentions": false,
|
128 |
+
"output_hidden_states": false,
|
129 |
+
"output_scores": false,
|
130 |
+
"pad_token_id": null,
|
131 |
+
"prefix": null,
|
132 |
+
"problem_type": null,
|
133 |
+
"pruned_heads": {},
|
134 |
+
"remove_invalid_values": false,
|
135 |
+
"repetition_penalty": 1.0,
|
136 |
+
"return_dict": true,
|
137 |
+
"return_dict_in_generate": false,
|
138 |
+
"rope_scaling": null,
|
139 |
+
"rotary_emb_base": 10000,
|
140 |
+
"rotary_pct": 1.0,
|
141 |
+
"sep_token_id": null,
|
142 |
+
"suppress_tokens": null,
|
143 |
+
"task_specific_params": null,
|
144 |
+
"temperature": 1.0,
|
145 |
+
"tf_legacy_loss": false,
|
146 |
+
"tie_encoder_decoder": false,
|
147 |
+
"tie_word_embeddings": false,
|
148 |
+
"tokenizer_class": "T5Tokenizer",
|
149 |
+
"top_k": 50,
|
150 |
+
"top_p": 1.0,
|
151 |
+
"torch_dtype": "float32",
|
152 |
+
"torchscript": false,
|
153 |
+
"typical_p": 1.0,
|
154 |
+
"use_bfloat16": false,
|
155 |
+
"use_cache": true,
|
156 |
+
"use_parallel_residual": false,
|
157 |
+
"vocab_size": 44416
|
158 |
+
},
|
159 |
+
"tie_word_embeddings": false,
|
160 |
+
"torch_dtype": "float32",
|
161 |
+
"transformers_version": "4.35.2",
|
162 |
+
"use_decoder_only_language_model": true,
|
163 |
+
"vision_config": {
|
164 |
+
"_name_or_path": "openai/clip-vit-large-patch14",
|
165 |
+
"add_cross_attention": false,
|
166 |
+
"architectures": null,
|
167 |
+
"attention_dropout": 0.0,
|
168 |
+
"bad_words_ids": null,
|
169 |
+
"begin_suppress_tokens": null,
|
170 |
+
"bos_token_id": null,
|
171 |
+
"chunk_size_feed_forward": 0,
|
172 |
+
"cross_attention_hidden_size": null,
|
173 |
+
"decoder_start_token_id": null,
|
174 |
+
"diversity_penalty": 0.0,
|
175 |
+
"do_sample": false,
|
176 |
+
"dropout": 0.0,
|
177 |
+
"early_stopping": false,
|
178 |
+
"encoder_no_repeat_ngram_size": 0,
|
179 |
+
"eos_token_id": null,
|
180 |
+
"exponential_decay_length_penalty": null,
|
181 |
+
"finetuning_task": null,
|
182 |
+
"forced_bos_token_id": null,
|
183 |
+
"forced_eos_token_id": null,
|
184 |
+
"hidden_act": "quick_gelu",
|
185 |
+
"hidden_size": 1024,
|
186 |
+
"id2label": {
|
187 |
+
"0": "LABEL_0",
|
188 |
+
"1": "LABEL_1"
|
189 |
+
},
|
190 |
+
"image_size": 224,
|
191 |
+
"initializer_factor": 1.0,
|
192 |
+
"initializer_range": 0.02,
|
193 |
+
"intermediate_size": 4096,
|
194 |
+
"is_decoder": false,
|
195 |
+
"is_encoder_decoder": false,
|
196 |
+
"label2id": {
|
197 |
+
"LABEL_0": 0,
|
198 |
+
"LABEL_1": 1
|
199 |
+
},
|
200 |
+
"layer_norm_eps": 1e-05,
|
201 |
+
"length_penalty": 1.0,
|
202 |
+
"max_length": 20,
|
203 |
+
"min_length": 0,
|
204 |
+
"model_type": "clip_vision_model",
|
205 |
+
"no_repeat_ngram_size": 0,
|
206 |
+
"num_attention_heads": 16,
|
207 |
+
"num_beam_groups": 1,
|
208 |
+
"num_beams": 1,
|
209 |
+
"num_channels": 3,
|
210 |
+
"num_hidden_layers": 24,
|
211 |
+
"num_return_sequences": 1,
|
212 |
+
"output_attentions": false,
|
213 |
+
"output_hidden_states": false,
|
214 |
+
"output_scores": false,
|
215 |
+
"pad_token_id": null,
|
216 |
+
"patch_size": 14,
|
217 |
+
"prefix": null,
|
218 |
+
"problem_type": null,
|
219 |
+
"projection_dim": 768,
|
220 |
+
"pruned_heads": {},
|
221 |
+
"remove_invalid_values": false,
|
222 |
+
"repetition_penalty": 1.0,
|
223 |
+
"return_dict": true,
|
224 |
+
"return_dict_in_generate": false,
|
225 |
+
"sep_token_id": null,
|
226 |
+
"suppress_tokens": null,
|
227 |
+
"task_specific_params": null,
|
228 |
+
"temperature": 1.0,
|
229 |
+
"tf_legacy_loss": false,
|
230 |
+
"tie_encoder_decoder": false,
|
231 |
+
"tie_word_embeddings": true,
|
232 |
+
"tokenizer_class": null,
|
233 |
+
"top_k": 50,
|
234 |
+
"top_p": 1.0,
|
235 |
+
"torch_dtype": null,
|
236 |
+
"torchscript": false,
|
237 |
+
"typical_p": 1.0,
|
238 |
+
"use_bfloat16": false
|
239 |
+
},
|
240 |
+
"vision_select_feature": "patch",
|
241 |
+
"vision_select_layer": -2
|
242 |
+
}
|
configuration_llava.py
ADDED
@@ -0,0 +1,131 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 Stability AI team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
import os
|
15 |
+
from typing import Union
|
16 |
+
|
17 |
+
from transformers import PretrainedConfig, CLIPVisionConfig
|
18 |
+
from transformers.models.auto import CONFIG_MAPPING
|
19 |
+
from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
|
20 |
+
from transformers.utils import logging
|
21 |
+
|
22 |
+
|
23 |
+
logger = logging.get_logger(__name__)
|
24 |
+
|
25 |
+
|
26 |
+
class LlavaMlpConfig(PretrainedConfig):
|
27 |
+
model_type = "llava_mlp"
|
28 |
+
|
29 |
+
def __init__(
|
30 |
+
self,
|
31 |
+
num_hidden_layers=2,
|
32 |
+
**kwargs,
|
33 |
+
):
|
34 |
+
super().__init__(**kwargs)
|
35 |
+
|
36 |
+
self.num_hidden_layers = num_hidden_layers
|
37 |
+
|
38 |
+
@classmethod
|
39 |
+
def from_pretrained(
|
40 |
+
cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs
|
41 |
+
) -> "PretrainedConfig":
|
42 |
+
cls._set_token_in_kwargs(kwargs)
|
43 |
+
|
44 |
+
config_dict, kwargs = cls.get_config_dict(
|
45 |
+
pretrained_model_name_or_path, **kwargs
|
46 |
+
)
|
47 |
+
|
48 |
+
# get the qformer config dict if we are loading from InstructBlipConfig
|
49 |
+
if config_dict.get("model_type") == "llava":
|
50 |
+
config_dict = config_dict["mlp_config"]
|
51 |
+
|
52 |
+
if (
|
53 |
+
"model_type" in config_dict
|
54 |
+
and hasattr(cls, "model_type")
|
55 |
+
and config_dict["model_type"] != cls.model_type
|
56 |
+
):
|
57 |
+
logger.warning(
|
58 |
+
f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
|
59 |
+
f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
|
60 |
+
)
|
61 |
+
|
62 |
+
return cls.from_dict(config_dict, **kwargs)
|
63 |
+
|
64 |
+
|
65 |
+
class LlavaConfig(PretrainedConfig):
|
66 |
+
model_type = "llava"
|
67 |
+
is_composition = True
|
68 |
+
|
69 |
+
def __init__(
|
70 |
+
self,
|
71 |
+
vision_config=None,
|
72 |
+
mlp_config=None,
|
73 |
+
text_config=None,
|
74 |
+
vision_select_layer=-2,
|
75 |
+
vision_select_feature="patch",
|
76 |
+
**kwargs,
|
77 |
+
):
|
78 |
+
super().__init__(**kwargs)
|
79 |
+
|
80 |
+
if vision_config is None:
|
81 |
+
vision_config = {}
|
82 |
+
logger.info(
|
83 |
+
"vision_config is None. initializing the CLIPVisionConfig with default values."
|
84 |
+
)
|
85 |
+
|
86 |
+
if mlp_config is None:
|
87 |
+
mlp_config = {}
|
88 |
+
logger.info(
|
89 |
+
"mlp_config is None. Initializing the LlavaMlpConfig with default values."
|
90 |
+
)
|
91 |
+
|
92 |
+
if text_config is None:
|
93 |
+
text_config = {}
|
94 |
+
logger.info(
|
95 |
+
"text_config is None. Initializing the text config with default values (`OPTConfig`)."
|
96 |
+
)
|
97 |
+
|
98 |
+
self.vision_config = CLIPVisionConfig(**vision_config)
|
99 |
+
self.mlp_config = LlavaMlpConfig(**mlp_config)
|
100 |
+
text_model_type = text_config["model_type"]
|
101 |
+
self.text_config = CONFIG_MAPPING[text_model_type](**text_config)
|
102 |
+
|
103 |
+
self.tie_word_embeddings = self.text_config.tie_word_embeddings
|
104 |
+
self.is_encoder_decoder = self.text_config.is_encoder_decoder
|
105 |
+
|
106 |
+
self.use_decoder_only_language_model = (
|
107 |
+
self.text_config.model_type in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
|
108 |
+
)
|
109 |
+
self.vision_select_layer = vision_select_layer
|
110 |
+
assert vision_select_feature in [
|
111 |
+
"cls_patch",
|
112 |
+
"patch",
|
113 |
+
], f"Unexpected select feature: {vision_select_feature}"
|
114 |
+
self.vision_select_feature = vision_select_feature
|
115 |
+
self.initializer_factor = 1.0
|
116 |
+
self.initializer_range = 0.02
|
117 |
+
|
118 |
+
@classmethod
|
119 |
+
def from_vision_mlp_text_configs(
|
120 |
+
cls,
|
121 |
+
vision_config: CLIPVisionConfig,
|
122 |
+
mlp_config: LlavaMlpConfig,
|
123 |
+
text_config: PretrainedConfig,
|
124 |
+
**kwargs,
|
125 |
+
):
|
126 |
+
return cls(
|
127 |
+
vision_config=vision_config.to_dict(),
|
128 |
+
mlp_config=mlp_config.to_dict(),
|
129 |
+
text_config=text_config.to_dict(),
|
130 |
+
**kwargs,
|
131 |
+
)
|
generation_config.json
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"_from_model_config": true,
|
3 |
+
"bos_token_id": 2,
|
4 |
+
"eos_token_id": 3,
|
5 |
+
"transformers_version": "4.35.2"
|
6 |
+
}
|
model.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:53edeacc890e71dc2a330ca346e026e6f98af20550f73390e2286d1cff262c6a
|
3 |
+
size 1831419000
|
modeling_llava.py
ADDED
@@ -0,0 +1,345 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 Stability AI team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
from typing import Optional, Tuple, Union, Any
|
15 |
+
from dataclasses import dataclass
|
16 |
+
import torch
|
17 |
+
import torch.nn as nn
|
18 |
+
from torch.nn import CrossEntropyLoss
|
19 |
+
|
20 |
+
from transformers import (
|
21 |
+
AutoModelForCausalLM,
|
22 |
+
AutoModelForSeq2SeqLM,
|
23 |
+
PreTrainedModel,
|
24 |
+
CLIPVisionModel,
|
25 |
+
)
|
26 |
+
|
27 |
+
from transformers.utils import logging, ModelOutput
|
28 |
+
from .configuration_llava import LlavaConfig
|
29 |
+
|
30 |
+
|
31 |
+
logger = logging.get_logger(__name__)
|
32 |
+
|
33 |
+
|
34 |
+
@dataclass
|
35 |
+
class LlavaForConditionalGenerationModelOutput(ModelOutput):
|
36 |
+
loss: Optional[Tuple[torch.FloatTensor]] = None
|
37 |
+
logits: Optional[Tuple[torch.FloatTensor]] = None
|
38 |
+
vision_outputs: Optional[torch.FloatTensor] = None
|
39 |
+
language_model_outputs: Optional[Tuple[torch.FloatTensor]] = None
|
40 |
+
|
41 |
+
def to_tuple(self) -> Tuple[Any]:
|
42 |
+
return tuple(
|
43 |
+
self[k]
|
44 |
+
if k not in ["vision_outputs", "language_model_outputs"]
|
45 |
+
else getattr(self, k).to_tuple()
|
46 |
+
for k in self.keys()
|
47 |
+
)
|
48 |
+
|
49 |
+
|
50 |
+
class LlavaPreTrainedModel(PreTrainedModel):
|
51 |
+
"""
|
52 |
+
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
53 |
+
models.
|
54 |
+
"""
|
55 |
+
|
56 |
+
config_class = LlavaConfig
|
57 |
+
base_model_prefix = "llava"
|
58 |
+
|
59 |
+
# Copied from transformers.models.blip_2.modeling_blip_2.Blip2PreTrainedModel._init_weights with Blip2->InstructBlip
|
60 |
+
def _init_weights(self, module):
|
61 |
+
"""Initialize the weights"""
|
62 |
+
factor = self.config.initializer_range
|
63 |
+
if (
|
64 |
+
isinstance(module, nn.Conv2d)
|
65 |
+
or isinstance(module, nn.Embedding)
|
66 |
+
or isinstance(module, nn.Linear)
|
67 |
+
):
|
68 |
+
module.weight.data.normal_(mean=0.0, std=factor)
|
69 |
+
if hasattr(module, "bias") and module.bias is not None:
|
70 |
+
module.bias.data.zero_()
|
71 |
+
|
72 |
+
elif isinstance(module, nn.LayerNorm):
|
73 |
+
module.bias.data.zero_()
|
74 |
+
module.weight.data.fill_(1.0)
|
75 |
+
elif isinstance(module, nn.Linear) and module.bias is not None:
|
76 |
+
module.bias.data.zero_()
|
77 |
+
|
78 |
+
|
79 |
+
class LlavaForConditionalGeneration(LlavaPreTrainedModel):
|
80 |
+
config_class = LlavaConfig
|
81 |
+
main_input_name = "pixel_values"
|
82 |
+
_no_split_modules = []
|
83 |
+
|
84 |
+
def __init__(self, config: LlavaConfig):
|
85 |
+
super().__init__(config)
|
86 |
+
|
87 |
+
self.vision_model = CLIPVisionModel(config.vision_config)
|
88 |
+
if config.use_decoder_only_language_model:
|
89 |
+
language_model = AutoModelForCausalLM.from_config(config.text_config)
|
90 |
+
else:
|
91 |
+
language_model = AutoModelForSeq2SeqLM.from_config(config.text_config)
|
92 |
+
|
93 |
+
if language_model._no_split_modules is not None:
|
94 |
+
self._no_split_modules.extend(language_model._no_split_modules)
|
95 |
+
|
96 |
+
if language_model._keep_in_fp32_modules is not None:
|
97 |
+
self._keep_in_fp32_modules.extend(language_model._keep_in_fp32_modules)
|
98 |
+
|
99 |
+
self.language_model = language_model
|
100 |
+
|
101 |
+
modules = [
|
102 |
+
nn.Linear(config.vision_config.hidden_size, config.text_config.hidden_size)
|
103 |
+
]
|
104 |
+
for _ in range(1, config.mlp_config.num_hidden_layers):
|
105 |
+
modules.append(nn.GELU())
|
106 |
+
modules.append(
|
107 |
+
nn.Linear(
|
108 |
+
config.text_config.hidden_size, config.text_config.hidden_size
|
109 |
+
)
|
110 |
+
)
|
111 |
+
self.mlp = nn.Sequential(*modules)
|
112 |
+
|
113 |
+
# Initialize weights and apply final processing
|
114 |
+
self.post_init()
|
115 |
+
|
116 |
+
def get_input_embeddings(self):
|
117 |
+
return self.language_model.get_input_embeddings()
|
118 |
+
|
119 |
+
def set_input_embeddings(self, value):
|
120 |
+
self.language_model.set_input_embeddings(value)
|
121 |
+
|
122 |
+
def set_output_embeddings(self, new_embeddings):
|
123 |
+
self.language_model.set_output_embeddings(new_embeddings)
|
124 |
+
|
125 |
+
def get_output_embeddings(self) -> nn.Module:
|
126 |
+
return self.language_model.get_output_embeddings()
|
127 |
+
|
128 |
+
def get_encoder(self):
|
129 |
+
return self.language_model.get_encoder()
|
130 |
+
|
131 |
+
def get_decoder(self):
|
132 |
+
return self.language_model.get_decoder()
|
133 |
+
|
134 |
+
def _tie_weights(self):
|
135 |
+
if not self.config.use_decoder_only_language_model:
|
136 |
+
self.language_model.encoder.embed_tokens = self.language_model.shared
|
137 |
+
self.language_model.decoder.embed_tokens = self.language_model.shared
|
138 |
+
|
139 |
+
def _preprocess_accelerate(self):
|
140 |
+
r"""
|
141 |
+
Some pre-processing hacks to make the model `accelerate` compatible. Check
|
142 |
+
https://github.com/huggingface/transformers/pull/21707 for more details.
|
143 |
+
"""
|
144 |
+
hf_device_map = self.hf_device_map
|
145 |
+
|
146 |
+
if (
|
147 |
+
len(hf_device_map) > 1
|
148 |
+
and "language_model" not in hf_device_map
|
149 |
+
and torch.cuda.device_count() > 1
|
150 |
+
):
|
151 |
+
# warn users about unexpected behavior when using multi-GPU + InstructBLIP + `accelerate`.
|
152 |
+
logger.warning(
|
153 |
+
"The `language_model` is not in the `hf_device_map` dictionary and you are running your script"
|
154 |
+
" in a multi-GPU environment. this may lead to unexpected behavior when using `accelerate`."
|
155 |
+
" Please pass a `device_map` that contains `language_model` to remove this warning."
|
156 |
+
" Please refer to https://github.com/huggingface/blog/blob/main/accelerate-large-models.md for"
|
157 |
+
" more details on creating a `device_map` for large models.",
|
158 |
+
)
|
159 |
+
|
160 |
+
if hasattr(self.language_model, "_hf_hook"):
|
161 |
+
self.language_model._hf_hook.io_same_device = (
|
162 |
+
True # For `generate` compatibility
|
163 |
+
)
|
164 |
+
|
165 |
+
def forward(
|
166 |
+
self,
|
167 |
+
pixel_values: torch.FloatTensor,
|
168 |
+
input_ids: Optional[torch.FloatTensor] = None,
|
169 |
+
attention_mask: Optional[torch.LongTensor] = None,
|
170 |
+
decoder_input_ids: Optional[torch.LongTensor] = None,
|
171 |
+
decoder_attention_mask: Optional[torch.LongTensor] = None,
|
172 |
+
output_attentions: Optional[bool] = None,
|
173 |
+
output_hidden_states: Optional[bool] = None,
|
174 |
+
labels: Optional[torch.LongTensor] = None,
|
175 |
+
return_dict: Optional[bool] = None,
|
176 |
+
) -> Union[Tuple, LlavaForConditionalGenerationModelOutput]:
|
177 |
+
return_dict = (
|
178 |
+
return_dict if return_dict is not None else self.config.use_return_dict
|
179 |
+
)
|
180 |
+
|
181 |
+
# step 1: forward the images through the vision encoder,
|
182 |
+
vision_outputs = self.vision_model(
|
183 |
+
pixel_values=pixel_values,
|
184 |
+
output_attentions=output_attentions,
|
185 |
+
return_dict=return_dict,
|
186 |
+
output_hidden_states=True,
|
187 |
+
)
|
188 |
+
# (bsz, seq len, hidden_size)
|
189 |
+
image_embeds = vision_outputs.hidden_states[self.config.vision_select_layer]
|
190 |
+
if self.config.vision_select_feature == "patch":
|
191 |
+
image_embeds = image_embeds[:, 1:]
|
192 |
+
elif self.config.vision_select_feature == "cls_patch":
|
193 |
+
image_embeds = image_embeds
|
194 |
+
else:
|
195 |
+
raise ValueError(f"Unexpected select feature: {self.select_feature}")
|
196 |
+
|
197 |
+
# step 2: forward the image embeddings through the mlp
|
198 |
+
image_embeds = self.mlp(image_embeds)
|
199 |
+
image_attention_mask = torch.ones(
|
200 |
+
image_embeds.size()[:-1], device=image_embeds.device
|
201 |
+
)
|
202 |
+
inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
|
203 |
+
|
204 |
+
# step 3: concatenate
|
205 |
+
inputs_embeds = torch.cat(
|
206 |
+
[image_embeds, inputs_embeds.to(image_embeds.device)],
|
207 |
+
dim=1,
|
208 |
+
)
|
209 |
+
|
210 |
+
if attention_mask is None:
|
211 |
+
attention_mask = torch.ones_like(input_ids, device=input_ids.device)
|
212 |
+
|
213 |
+
attention_mask = torch.cat(
|
214 |
+
[image_attention_mask.to(attention_mask.device), attention_mask],
|
215 |
+
dim=1,
|
216 |
+
)
|
217 |
+
|
218 |
+
if self.config.use_decoder_only_language_model:
|
219 |
+
outputs = self.language_model(
|
220 |
+
inputs_embeds=inputs_embeds,
|
221 |
+
attention_mask=attention_mask,
|
222 |
+
output_attentions=output_attentions,
|
223 |
+
output_hidden_states=output_hidden_states,
|
224 |
+
return_dict=return_dict,
|
225 |
+
)
|
226 |
+
logits = outputs.logits if return_dict else outputs[0]
|
227 |
+
loss = None
|
228 |
+
# we compute the loss here since we need to take into account the sequence length of the query embeds
|
229 |
+
if labels is not None:
|
230 |
+
labels = labels.to(logits.device)
|
231 |
+
logits = logits[:, -labels.size(1) :, :]
|
232 |
+
# Shift so that tokens < n predict n
|
233 |
+
shift_logits = logits[..., :-1, :].contiguous()
|
234 |
+
shift_labels = labels[..., 1:].contiguous().to(logits.device)
|
235 |
+
|
236 |
+
# Flatten the tokens
|
237 |
+
loss_fct = CrossEntropyLoss(reduction="mean")
|
238 |
+
|
239 |
+
loss = loss_fct(
|
240 |
+
shift_logits.view(-1, self.config.text_config.vocab_size),
|
241 |
+
shift_labels.view(-1),
|
242 |
+
)
|
243 |
+
else:
|
244 |
+
outputs = self.language_model(
|
245 |
+
inputs_embeds=inputs_embeds,
|
246 |
+
attention_mask=attention_mask,
|
247 |
+
decoder_input_ids=decoder_input_ids,
|
248 |
+
decoder_attention_mask=decoder_attention_mask,
|
249 |
+
output_attentions=output_attentions,
|
250 |
+
output_hidden_states=output_hidden_states,
|
251 |
+
return_dict=return_dict,
|
252 |
+
labels=labels,
|
253 |
+
)
|
254 |
+
loss = outputs.loss if return_dict else outputs[0]
|
255 |
+
logits = outputs.logits if return_dict else outputs[1]
|
256 |
+
|
257 |
+
if not return_dict:
|
258 |
+
output = (logits, vision_outputs, outputs)
|
259 |
+
return ((loss,) + output) if loss is not None else output
|
260 |
+
|
261 |
+
return LlavaForConditionalGenerationModelOutput(
|
262 |
+
loss=loss,
|
263 |
+
logits=logits,
|
264 |
+
vision_outputs=vision_outputs,
|
265 |
+
language_model_outputs=outputs,
|
266 |
+
)
|
267 |
+
|
268 |
+
def get_image_embeds(self, pixel_values: torch.FloatTensor):
|
269 |
+
vision_outputs = self.vision_model(
|
270 |
+
pixel_values=pixel_values,
|
271 |
+
output_hidden_states=True,
|
272 |
+
)
|
273 |
+
image_embeds = vision_outputs.hidden_states[self.config.vision_select_layer]
|
274 |
+
if self.config.vision_select_feature == "patch":
|
275 |
+
image_embeds = image_embeds[:, 1:]
|
276 |
+
elif self.config.vision_select_feature == "cls_patch":
|
277 |
+
image_embeds = image_embeds
|
278 |
+
else:
|
279 |
+
raise ValueError(f"Unexpected select feature: {self.select_feature}")
|
280 |
+
|
281 |
+
image_embeds = self.mlp(image_embeds)
|
282 |
+
image_attention_mask = torch.ones(
|
283 |
+
image_embeds.size()[:-1], device=image_embeds.device
|
284 |
+
)
|
285 |
+
return dict(
|
286 |
+
image_embeds=image_embeds,
|
287 |
+
image_attention_mask=image_attention_mask,
|
288 |
+
)
|
289 |
+
|
290 |
+
def prepare_for_lm_generation(
|
291 |
+
self,
|
292 |
+
pixel_values: torch.FloatTensor,
|
293 |
+
input_ids: Optional[torch.LongTensor] = None,
|
294 |
+
attention_mask: Optional[torch.LongTensor] = None,
|
295 |
+
):
|
296 |
+
batch_size = pixel_values.shape[0]
|
297 |
+
vision_outputs = self.get_image_embeds(pixel_values)
|
298 |
+
image_embeds = vision_outputs["image_embeds"]
|
299 |
+
image_attention_mask = vision_outputs["image_attention_mask"]
|
300 |
+
|
301 |
+
if input_ids is None:
|
302 |
+
input_ids = (
|
303 |
+
torch.LongTensor([[self.config.text_config.bos_token_id]])
|
304 |
+
.repeat(batch_size, 1)
|
305 |
+
.to(image_embeds.device)
|
306 |
+
)
|
307 |
+
if attention_mask is None:
|
308 |
+
attention_mask = torch.ones_like(input_ids)
|
309 |
+
attention_mask = torch.cat(
|
310 |
+
[
|
311 |
+
image_attention_mask,
|
312 |
+
attention_mask.to(image_attention_mask.device),
|
313 |
+
],
|
314 |
+
dim=1,
|
315 |
+
)
|
316 |
+
|
317 |
+
# concatenate query embeddings with prompt embeddings
|
318 |
+
inputs_embeds = self.get_input_embeddings()(input_ids)
|
319 |
+
inputs_embeds = torch.cat(
|
320 |
+
[image_embeds, inputs_embeds.to(image_embeds.device)],
|
321 |
+
dim=1,
|
322 |
+
)
|
323 |
+
return dict(inputs_embeds=inputs_embeds, attention_mask=attention_mask)
|
324 |
+
|
325 |
+
@torch.no_grad()
|
326 |
+
def generate(
|
327 |
+
self,
|
328 |
+
pixel_values: torch.FloatTensor,
|
329 |
+
input_ids: Optional[torch.LongTensor] = None,
|
330 |
+
attention_mask: Optional[torch.LongTensor] = None,
|
331 |
+
**generate_kwargs,
|
332 |
+
) -> torch.LongTensor:
|
333 |
+
if hasattr(self, "hf_device_map"):
|
334 |
+
# preprocess for `accelerate`
|
335 |
+
self._preprocess_accelerate()
|
336 |
+
encodings = self.prepare_for_lm_generation(
|
337 |
+
pixel_values=pixel_values,
|
338 |
+
input_ids=input_ids,
|
339 |
+
attention_mask=attention_mask,
|
340 |
+
)
|
341 |
+
outputs = self.language_model.generate(
|
342 |
+
**encodings,
|
343 |
+
**generate_kwargs,
|
344 |
+
)
|
345 |
+
return outputs
|