g-h-chen committed on
Commit
34b5eca
1 Parent(s): 6efcca1

upload modeling_llava_phi3.py

Files changed (1)
  1. modeling_llava_phi3.py +334 -0
modeling_llava_phi3.py ADDED
@@ -0,0 +1,334 @@
# Copyright 2023 Haotian Liu
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn
import math
import sys
import pdb
from typing import Dict, Any

from transformers import AutoConfig, AutoModelForCausalLM, PretrainedConfig, PreTrainedModel
# MistralConfig, MistralModel, MistralForCausalLM

from transformers.modeling_outputs import CausalLMOutputWithPast

from transformers.cache_utils import Cache, DynamicCache

from .llava_arch import LlavaMetaModel, LlavaMetaForCausalLM
from .modeling_phi3 import Phi3ForCausalLM, Phi3Model, Phi3Config
from .generation_utils import build_allava_input


################ Phi ###############################

class LlavaPhi3Config(Phi3Config):
    model_type = "llava_phi3"


class LlavaPhi3Model(LlavaMetaModel, Phi3Model):
    config_class = LlavaPhi3Config

    def __init__(self, config: Phi3Config):
        super(LlavaPhi3Model, self).__init__(config)


class LlavaPhi3ForCausalLM(Phi3ForCausalLM, LlavaMetaForCausalLM):
    config_class = LlavaPhi3Config

    def __init__(self, config, init_vision_encoder_from_ckpt=True):
        config.flash_attn = True
        config.flash_rotary = True
        config.fused_dense = True
        config._attn_implementation = "flash_attention_2"

        super(Phi3ForCausalLM, self).__init__(config)
        # self.model is used in LlavaMetaForCausalLM.get_model(); self.transformer is used in PhiForCausalLM.forward()
        self.model = LlavaPhi3Model(config)
        if hasattr(self.model, '_use_flash_attention_2'):
            assert self.model._use_flash_attention_2, 'flash attn is not enabled. check it out!'
        # self.pretraining_tp = config.pretraining_tp
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        if init_vision_encoder_from_ckpt:
            vision_tower = self.get_vision_tower()
            print('loading from CLIP first. This should only be used at inference!!!')
            vision_tower.load_model()

        # Initialize weights and apply final processing
        self.post_init()

    # ############ these two methods are missing in modeling_phi.py
    # def get_input_embeddings(self) -> nn.Embedding:
    #     return self.model.embd.wte

    # def set_input_embeddings(self, new_embeddings: nn.Embedding) -> None:
    #     self.model.embd.wte = new_embeddings
    # ############ these two methods are missing in modeling_phi.py

    def get_model(self):
        return self.model

    def get_tokenizer(self):
        return self.tokenizer

    def get_processor(self):
        return self.model.vision_tower.image_processor

    def set_tokenizer_eos_id(self):
        eos_token_id = 30027  # only for llava_phi3
        self.tokenizer.eos_token_id = eos_token_id

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        images: Optional[torch.FloatTensor] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:

        # If raw token ids are given, splice the image features into the text embeddings first.
        # pdb.set_trace()
        if inputs_embeds is None:
            (
                input_ids,
                position_ids,
                attention_mask,
                past_key_values,
                inputs_embeds,
                labels
            # ) = self.prepare_inputs_labels_for_multimodal(
            ) = self.prepare_inputs_labels_for_multimodal_new(
                input_ids,
                position_ids,
                attention_mask,
                past_key_values,
                labels,
                images
            )

        return super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            labels=labels,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict
        )

    @torch.no_grad()
    def generate(
        self,
        inputs: Optional[torch.Tensor] = None,
        images: Optional[torch.Tensor] = None,
        **kwargs,
    ):
        position_ids = kwargs.pop("position_ids", None)
        attention_mask = kwargs.pop("attention_mask", None)
        if "inputs_embeds" in kwargs:
            raise NotImplementedError("`inputs_embeds` is not supported")

        if images is not None:
            # Build the multimodal embeddings; the returned past_key_values and labels are unused here.
            (
                inputs,
                position_ids,
                attention_mask,
                _,
                inputs_embeds,
                _
            ) = self.prepare_inputs_labels_for_multimodal_new(
                inputs,
                position_ids,
                attention_mask,
                None,
                None,
                images
            )
        else:
            inputs_embeds = self.get_model().embed_tokens(inputs)

        # print(inputs_embeds.shape)
        return super().generate(
            position_ids=None,
            attention_mask=None,
            inputs_embeds=inputs_embeds,
            **kwargs
        )

    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs):
        '''
        This function is called once per generated token at inference time.
        '''
        # pdb.set_trace()
        images = kwargs.pop("images", None)

        ####################################################
        # lines from modeling_phi.py
        ####################################################

        if past_key_values is not None:
            if isinstance(past_key_values, Cache):
                cache_length = past_key_values.get_seq_length()
                past_length = past_key_values.seen_tokens
                max_cache_length = past_key_values.get_max_length()
            else:
                cache_length = past_length = past_key_values[0][0].shape[2]
                max_cache_length = None

            # Keep only the unprocessed tokens:
            # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
            # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
            # input)
            if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
                input_ids = input_ids[:, -(attention_mask.shape[1] - past_length):]
            # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
            # input_ids based on the past_length.
            elif past_length < input_ids.shape[1]:
                input_ids = input_ids[:, past_length:]
            # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
            elif past_length >= input_ids.shape[1]:
                input_ids = input_ids[:, [-1]]  # only keep the last one!

            # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
            if (
                max_cache_length is not None
                and attention_mask is not None
                and cache_length + input_ids.shape[1] > max_cache_length
            ):
                attention_mask = attention_mask[:, -max_cache_length:]

        position_ids = kwargs.get("position_ids", None)
        if attention_mask is not None and position_ids is None:
            # create position_ids on the fly for batch generation
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values:
                position_ids = position_ids[:, -input_ids.shape[1]:]

        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}

        model_inputs.update(
            {
                "position_ids": position_ids,
                "past_key_values": past_key_values,
                "use_cache": kwargs.get("use_cache"),
                "attention_mask": attention_mask,
            }
        )
        ####################################################
        # end of lines from modeling_phi.py
        ####################################################

        # Forward the image tensors so that forward() can splice them in on the first step.
        if images is not None:
            model_inputs['images'] = images
        return model_inputs

    # def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
    #     images = kwargs.pop("images", None)
    #     _inputs = super().prepare_inputs_for_generation(
    #         input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs
    #     )
    #     if images is not None:
    #         _inputs['images'] = images
    #     return _inputs

    def chat(
        self,
        texts: Optional[str | list[list[str, str]]],
        images: Optional[str | list[str]] = None,
        history: Optional[list[str]] = None,
        stream = False,
        return_history = False,
        **kwargs
    ):
        '''
        texts: if `str`, generate a reply for a single round; if a list of [user, assistant] turns, generate with that conversation as context.
        images: str or list[str] (optional), local path(s) to image(s).
        '''
        use_cache = kwargs.pop('use_cache', True)

        if 'eos_token_id' in kwargs:
            _ = kwargs.pop('eos_token_id', None)
            print(f'eos_token_id {_} from gen_kwargs is popped since it is not needed.')
        # pdb.set_trace()

        ############################
        # merge history
        ############################
        input_ids, image_tensors, history = build_allava_input(
            tokenizer=self.get_tokenizer(),
            processor=self.get_processor(),
            texts=texts,
            images=images,
            history=history,
            return_history=return_history,
            device=self.device
        )

        ############################
        # generate response
        ############################
        # with torch.autocast(device_type='cuda'):
        if 'cuda' in str(self.device):
            device_type = 'cuda'
        else:
            device_type = 'cpu'

        with torch.autocast(device_type=device_type, dtype=self.dtype):
            output_ids = self.generate(
                inputs=input_ids,
                images=image_tensors,
                use_cache=use_cache,
                **kwargs)

        answer = self.get_tokenizer().decode(output_ids[0, :], skip_special_tokens=True).strip()

        if return_history:
            history[-1][-1] = answer
            return answer, history
        return answer


AutoConfig.register("llava_phi3", LlavaPhi3Config)
AutoModelForCausalLM.register(LlavaPhi3Config, LlavaPhi3ForCausalLM)
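
For context, a minimal usage sketch (not part of the committed file). It assumes this module ships in a model repo alongside llava_arch.py, modeling_phi3.py and generation_utils.py and is loaded with trust_remote_code; the repo id and image path below are placeholders.

# Hypothetical usage example; "some-org/llava-phi3-mini" and "example.jpg" are placeholders.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

repo_id = "some-org/llava-phi3-mini"  # placeholder repo id

model = AutoModelForCausalLM.from_pretrained(
    repo_id,
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,  # loads LlavaPhi3ForCausalLM from this file
    device_map="cuda",
)
model.tokenizer = AutoTokenizer.from_pretrained(repo_id)  # chat() reads self.tokenizer

answer = model.chat(
    texts="What is shown in this image?",
    images="example.jpg",  # local path, as accepted by chat()
    max_new_tokens=256,
)
print(answer)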