Text Generation · Transformers · Safetensors · English · llava_phi · custom_code
g-h-chen committed · commit 971fc91 · 1 parent: f185b0e

upload modeling_llava_phi.py

Files changed (1)
  1. modeling_llava_phi.py +289 -0
modeling_llava_phi.py ADDED
@@ -0,0 +1,289 @@
+from typing import List, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+import math
+import pdb
+from typing import Dict, Any
+from PIL import Image
+
+from transformers import AutoConfig, AutoModelForCausalLM, PretrainedConfig, PreTrainedModel
+
+from transformers.modeling_outputs import CausalLMOutputWithPast
+
+from .llava_arch import LlavaMetaModel, LlavaMetaForCausalLM
+
+from transformers.cache_utils import Cache, DynamicCache
+
+from transformers.generation.utils import GenerationConfig
+
+import sys
+from .modeling_phi import PhiForCausalLM, PhiModel, PhiConfig
+from .generation_utils import build_allava_input
+
+
+################ Phi ###############################
+
+class LlavaPhiConfig(PhiConfig):
+    model_type = "llava_phi"
+
+
+class LlavaPhiModel(LlavaMetaModel, PhiModel):
+    config_class = LlavaPhiConfig
+
+    def __init__(self, config: PhiConfig):
+        super(LlavaPhiModel, self).__init__(config)
+
+
+class LlavaPhiForCausalLM(PhiForCausalLM, LlavaMetaForCausalLM):
+    config_class = LlavaPhiConfig
+
+    def __init__(self, config, init_vision_encoder_from_ckpt=True):
+        # Note: the default value is set to True for this inference version. In training, `init_vision_encoder_from_ckpt` defaults to True.
+        config._attn_implementation = "flash_attention_2"
+
+        super(PhiForCausalLM, self).__init__(config)
+        # self.model is used in LlavaMetaForCausalLM.get_model(); self.transformer is used in PhiForCausalLM.forward()
+        self.model = LlavaPhiModel(config)
+        if hasattr(self.model, '_use_flash_attention_2'):
+            assert self.model._use_flash_attention_2, 'flash attn is not enabled. check it out!'
+        self.vocab_size = config.vocab_size
+        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+        if init_vision_encoder_from_ckpt:
+            vision_tower = self.get_vision_tower()
+            print('loading from CLIP first. This should only be used at inference!!!')
+            vision_tower.load_model()
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_model(self):
+        return self.model
+
+    def get_tokenizer(self):
+        return self.tokenizer
+
+    def get_processor(self):
+        return self.model.vision_tower.image_processor
+
+    def forward(
+        self,
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        images: Optional[torch.FloatTensor] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, CausalLMOutputWithPast]:
+
+        if inputs_embeds is None:
+            (
+                input_ids,
+                position_ids,
+                attention_mask,
+                past_key_values,
+                inputs_embeds,
+                labels
+            # ) = self.prepare_inputs_labels_for_multimodal(
+            ) = self.prepare_inputs_labels_for_multimodal_new(
+                input_ids,
+                position_ids,
+                attention_mask,
+                past_key_values,
+                labels,
+                images
+            )
+
+        # pdb.set_trace()
+        return super().forward(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            labels=labels,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict
+        )
+
+    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs):
+        '''
+        This function is called once per generated token at inference time.
+        '''
+        # pdb.set_trace()
+        images = kwargs.pop("images", None)
+
+        ####################################################
+        # lines from modeling_phi.py
+        ####################################################
+
+        if past_key_values is not None:
+            if isinstance(past_key_values, Cache):
+                cache_length = past_key_values.get_seq_length()
+                past_length = past_key_values.seen_tokens
+                max_cache_length = past_key_values.get_max_length()
+            else:
+                cache_length = past_length = past_key_values[0][0].shape[2]
+                max_cache_length = None
+
+            # Keep only the unprocessed tokens:
+            # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
+            #     some of the inputs are exclusively passed as part of the cache (e.g. when passing inputs_embeds as input).
+            if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
+                input_ids = input_ids[:, -(attention_mask.shape[1] - past_length):]
+            # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
+            #     input_ids based on the past_length.
+            elif past_length < input_ids.shape[1]:
+                input_ids = input_ids[:, past_length:]
+            # 3 - Otherwise (past_length >= input_ids.shape[1]), assume input_ids only has unprocessed tokens.
+            elif past_length >= input_ids.shape[1]:
+                input_ids = input_ids[:, [-1]]  # only keep the last one!
+
+            # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
+            if (
+                max_cache_length is not None
+                and attention_mask is not None
+                and cache_length + input_ids.shape[1] > max_cache_length
+            ):
+                attention_mask = attention_mask[:, -max_cache_length:]
+
+        position_ids = kwargs.get("position_ids", None)
+        if attention_mask is not None and position_ids is None:
+            # create position_ids on the fly for batch generation
+            position_ids = attention_mask.long().cumsum(-1) - 1
+            position_ids.masked_fill_(attention_mask == 0, 1)
+            if past_key_values:
+                position_ids = position_ids[:, -input_ids.shape[1]:]
+
+        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
+        if inputs_embeds is not None and past_key_values is None:
+            model_inputs = {"inputs_embeds": inputs_embeds}
+        else:
+            model_inputs = {"input_ids": input_ids}
+
+        model_inputs.update(
+            {
+                "position_ids": position_ids,
+                "past_key_values": past_key_values,
+                "use_cache": kwargs.get("use_cache"),
+                "attention_mask": attention_mask,
+            }
+        )
+        ####################################################
+        # end of lines from modeling_phi.py
+        ####################################################
+
+        if images is not None:
+            model_inputs['images'] = images
+        return model_inputs
+
+    # def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
+    #     '''
+    #     This function is called for each token at inference
+    #     '''
+    #     pdb.set_trace()
+    #     images = kwargs.pop("images", None)
+
+    #     _inputs = super().prepare_inputs_for_generation(
+    #         input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs
+    #     )
+    #     if images is not None:
+    #         _inputs['images'] = images
+    #     return _inputs
+
+    # def build_chat_input(self, text, images):
+    #     return inputs
+
+    # def chat(self, tokenizer, messages: List[dict], stream=False,
+    #          generation_config: Optional[GenerationConfig] = None):
+    #     generation_config = generation_config or self.generation_config
+    #     input_ids = build_chat_input(self, tokenizer, messages, generation_config.max_new_tokens)
+    #     if stream:
+    #         streamer = TextIterStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
+    #         Thread(target=self.generate, kwargs=dict(
+    #             inputs=input_ids, streamer=streamer,
+    #             generation_config=generation_config,
+    #         )).start()
+    #         return streamer
+    #     else:
+    #         outputs = self.generate(input_ids, generation_config=generation_config)
+    #         response = tokenizer.decode(outputs[0][len(input_ids[0]):], skip_special_tokens=True)
+    #         return response
+
+    # def collate_text_input(self, ):
+    #     pass
+
+
+    def chat(
+        self,
+        texts: Optional[str | list[list[str, str]]],
+        images: Optional[str | list[str]] = None,
+        history: Optional[list[str]] = None,
+        stream=False,
+        return_history=False,
+        **kwargs
+    ):
+        '''
+        texts: if `str`, generate for a single round; if a list, it is handed to `build_allava_input` together with `history` for multi-round generation.
+        images: optional; a local path to an image, or a list of such paths.
+        '''
+        use_cache = kwargs.pop('use_cache', True)
+
+        ############################
+        # merge history
+        ############################
+        input_ids, image_tensors, history = build_allava_input(
+            tokenizer=self.get_tokenizer(),
+            processor=self.get_processor(),
+            texts=texts,
+            images=images,
+            history=history,
+            return_history=return_history,
+            device=self.device
+        )
+
+        ############################
+        # generate response
+        ############################
+        # with torch.autocast(device_type='cuda'):
+        if 'cuda' in str(self.device):
+            device_type = 'cuda'
+        else:
+            device_type = 'cpu'
+
+        with torch.autocast(device_type=device_type, dtype=self.dtype):
+            output_ids = self.generate(
+                inputs=input_ids,
+                images=image_tensors,
+                use_cache=use_cache,
+                **kwargs)
+
+        answer = self.get_tokenizer().decode(output_ids[0, input_ids.shape[1]:], skip_special_tokens=True).strip()
+
+        if return_history:
+            history[-1][-1] = answer
+            return answer, history
+        return answer
+
+
+AutoConfig.register("llava_phi", LlavaPhiConfig)
+AutoModelForCausalLM.register(LlavaPhiConfig, LlavaPhiForCausalLM)
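
Because the last two lines register the `llava_phi` architecture with the Auto classes, a checkpoint shipping this file can be loaded through `AutoModelForCausalLM` with `trust_remote_code=True` and queried via the `chat()` method defined above. The sketch below is a minimal usage example, not part of this commit: the repository path, the manual `model.tokenizer` attachment (since `get_tokenizer()` simply returns `self.tokenizer`), and the generation arguments are illustrative assumptions.

from transformers import AutoModelForCausalLM, AutoTokenizer

model_path = "path/to/llava_phi_checkpoint"  # hypothetical local path or Hub id

# trust_remote_code=True is required because "llava_phi" is a custom architecture
# registered by this file; __init__ forces flash_attention_2, so a CUDA device with
# flash-attn installed is assumed here.
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    trust_remote_code=True,
    torch_dtype="auto",
    device_map="cuda",
)

# chat() fetches the tokenizer via get_tokenizer(), which returns self.tokenizer;
# attaching it manually mirrors what the surrounding loading code is assumed to do.
model.tokenizer = AutoTokenizer.from_pretrained(model_path)

# Single round with one image; extra keyword arguments are forwarded to generate().
answer, history = model.chat(
    texts="What is shown in this image?",
    images="./example.jpg",
    return_history=True,
    max_new_tokens=256,
)
print(answer)

# Follow-up round, reusing the returned history.
answer = model.chat(texts="Describe it in one sentence.", history=history)
print(answer)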