Vision-CAIR committed on
Commit 0553a59
1 Parent(s): 9d1413c

Push model using huggingface_hub.
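The commit message indicates the checkpoint and code were uploaded with huggingface_hub. A minimal sketch of how such a push is typically performed (the repo id and local folder are placeholders, and an access token must already be configured, e.g. via huggingface-cli login); the commented-out model.push_to_hub(...) call in from_config() below is the in-code counterpart of this step:

from huggingface_hub import HfApi

api = HfApi()
# Upload config.json, the model weights, and the custom modeling file
# (mini_gpt4_llama_v2.py) that the auto_map entries in config.json point to.
api.upload_folder(
    folder_path="./minigpt4_video_export",   # placeholder local export folder
    repo_id="Vision-CAIR/MiniGPT4-Video",    # placeholder repo id
    repo_type="model",
)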

Files changed (2)
  1. config.json +14 -8
  2. mini_gpt4_llama_v2.py +758 -0
config.json CHANGED
@@ -1,8 +1,12 @@
 {
   "arch": "mini_gpt4_llama_v2",
   "architectures": [
-    "MiniGPT4_llama_v2"
+    "MiniGPT4_Video"
   ],
+  "auto_map": {
+    "AutoConfig": "mini_gpt4_llama_v2.minigpt4_video_config",
+    "AutoModel": "mini_gpt4_llama_v2.MiniGPT4_Video"
+  },
   "chat_template": true,
   "ckpt": "checkpoints/video_llama_checkpoint_last.pth",
   "device": "cuda",
@@ -14,23 +18,25 @@
   "length": 50,
   "llama_model": "meta-llama/Llama-2-7b-chat-hf",
   "lora_alpha": 16,
+  "lora_dropout": 0.05,
   "lora_r": 64,
+  "lora_target_modules": [
+    "q_proj",
+    "v_proj"
+  ],
   "low_resource": true,
   "max_context_len": 3600,
   "max_txt_len": 256,
   "model_type": "minigpt4_video",
   "num_query_token": 32,
   "prompt": "",
+  "prompt_path": "",
+  "remove_template": false,
+  "token_pooling": true,
   "torch_dtype": "float32",
   "transformers_version": "4.42.3",
   "use_grad_checkpoint": true,
   "use_grad_checkpoint_llm": true,
-  "vit_precision": "fp16",
   "vit_model": "eva_clip_g",
-  "token_pooling": true,
-  "lora_target_modules" : ["q_proj","v_proj"],
-  "lora_dropout": 0.05,
-  "remove_template": false,
-  "prompt_path":""
-
+  "vit_precision": "fp16"
 }
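The new auto_map entries wire the Auto classes to the custom code file added below, so the checkpoint can be loaded with trust_remote_code. A minimal loading sketch under that assumption (the repo id is a placeholder, and the repo's own dependencies such as the minigpt4_video package still need to be importable):

from transformers import AutoConfig, AutoModel

repo_id = "Vision-CAIR/MiniGPT4-Video"  # placeholder repo id

# trust_remote_code lets transformers import mini_gpt4_llama_v2.py from the repo
# and resolve minigpt4_video_config / MiniGPT4_Video through the auto_map above.
config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)
model = AutoModel.from_pretrained(repo_id, trust_remote_code=True)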
mini_gpt4_llama_v2.py ADDED
@@ -0,0 +1,758 @@
1
+ import logging
2
+ import random
3
+
4
+ import torch
5
+ from torch.cuda.amp import autocast as autocast
6
+ import torch.nn as nn
7
+
8
+ from minigpt4_video.registry import registry
9
+ from minigpt4_video.blip2 import Blip2Base, disabled_train
10
+ # from minigpt4_video.modeling_llama_v2 import LlamaForCausalLM as llm_model
11
+ # from minigpt4_video.modeling_mistral import MistralForCausalLM as llm_model
12
+ from minigpt4_video.conversation import Conversation, SeparatorStyle, StoppingCriteriaList, StoppingCriteriaSub
13
+
14
+ from transformers import LlamaTokenizer
15
+ from transformers import BitsAndBytesConfig
16
+ from transformers import AutoConfig, AutoTokenizer
17
+ from peft import (
18
+ LoraConfig,
19
+ get_peft_model,
20
+ get_peft_model_state_dict,
21
+ prepare_model_for_int8_training,
22
+ set_peft_model_state_dict,
23
+ )
24
+ import time
25
+ import json
26
+ import numpy as np
27
+ import os
28
+ from transformers import PretrainedConfig
29
+ from transformers import PreTrainedModel
30
+ from typing import List
31
+ class minigpt4_video_config(PretrainedConfig):
32
+ model_type="minigpt4_video"
33
+ PRETRAINED_MODEL_CONFIG_DICT = {
34
+ "minigpt4_video": "configs/models/minigpt4.yaml",
35
+ }
36
+ def __init__(
37
+ self,
38
+ omg_config:dict = {},
39
+ **kwargs,
40
+ ):
41
+ for key, value in omg_config.items():
42
+ setattr(self, key, value)
43
+ super().__init__(**kwargs)
44
+
45
+ # def to_dict(self):
46
+ # output = super().to_dict()
47
+ # return output
48
+
49
+ @registry.register_model("mini_gpt4_llama_v2")
50
+ class MiniGPT4_Video(Blip2Base, PreTrainedModel):
51
+ """
52
+ BLIP2 GPT-LLAMA model.
53
+ """
54
+
55
+ PRETRAINED_MODEL_CONFIG_DICT = {
56
+ "minigpt4_video": "minigpt4/configs/models/minigpt4.yaml",
57
+ }
58
+ config_class=minigpt4_video_config
59
+
60
+ def __init__(
61
+ self,
62
+ cfg={},
63
+ ):
64
+ ## loop through the minigpt4_video_config object (or plain dict) and set each entry as an attribute
65
+ if isinstance(cfg, minigpt4_video_config):
66
+ cfg = cfg.to_dict()
67
+
68
+ for key, value in cfg.items():
69
+ try:
70
+ setattr(self, key, value)
71
+ except Exception:
72
+ print(f"Error setting attribute {key} with value {value}")
73
+ PreTrainedModel.__init__(self, minigpt4_video_config(cfg))
74
+ Blip2Base.__init__(self)
75
+ if "Mistral" in self.llama_model:
76
+ from minigpt4_video.modeling_mistral import MistralForCausalLM as llm_model
77
+ print("Mistral model")
78
+ self.model_type = "Mistral"
79
+ else:
80
+ from minigpt4_video.modeling_llama_v2 import LlamaForCausalLM as llm_model
81
+ print("Llama model")
82
+ self.model_type = "Llama"
83
+ self.tokenizer = self.init_tokenizer()
84
+
85
+ print("token pooling", self.token_pooling)
86
+ if self.freeze_vit:
87
+ # self.vit_precision="fp32"
88
+ print("vit precision", self.vit_precision)
89
+ self.visual_encoder, self.ln_vision = self.init_vision_encoder(
90
+ self.vit_model, self.img_size, self.drop_path_rate, self.use_grad_checkpoint, self.vit_precision
91
+ )
92
+ for name, param in self.visual_encoder.named_parameters():
93
+ param.requires_grad = False
94
+ self.visual_encoder = self.visual_encoder.eval()
95
+ self.visual_encoder.train = disabled_train
96
+ for name, param in self.ln_vision.named_parameters():
97
+ param.requires_grad = False
98
+ self.ln_vision = self.ln_vision.eval()
99
+ self.ln_vision.train = disabled_train
100
+ logging.info("freeze vision encoder")
101
+ print("freeze the vision encoder")
102
+
103
+ else:
104
+ self.vit_precision="fp32"
105
+ self.visual_encoder, self.ln_vision = self.init_vision_encoder(
106
+ self.vit_model, self.img_size, self.drop_path_rate, self.use_grad_checkpoint, self.vit_precision
107
+ )
108
+
109
+ print("unfreeze the vision encoder")
110
+ print('Loading VIT Done')
111
+
112
+ print('Loading LLAMA')
113
+
114
+ self.B_SYS, self.E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
115
+ token=os.environ.get("HF_TKN")
116
+ self.llama_tokenizer = LlamaTokenizer.from_pretrained(self.llama_model,use_fast=False,token=token) #
117
+ self.llama_tokenizer.pad_token = "$$"
118
+ # use fastv
119
+ self.use_fastv = False
120
+ print("self.low_resource",self.low_resource)
121
+ if self.low_resource:
122
+ self.llama_model = llm_model.from_pretrained(
123
+ self.llama_model,
124
+ torch_dtype=torch.float16,
125
+ # torch_dtype = torch.bfloat16,
126
+ load_in_8bit=True,
127
+ # device_map = "balanced"
128
+ # device_map="auto",
129
+ device_map={'':torch.cuda.current_device()},token=token
130
+ # device_map={'':0}
131
+
132
+ )
133
+ else:
134
+ self.llama_model = llm_model.from_pretrained(
135
+ self.llama_model,
136
+ torch_dtype=torch.float16,token=token
137
+ )
138
+
139
+ # self.llama_model.resize_token_embeddings(len(self.llama_tokenizer))
140
+ self.llama_model = prepare_model_for_int8_training(self.llama_model)
141
+ loraconfig = LoraConfig(
142
+ r=self.lora_r,
143
+ lora_alpha=self.lora_alpha,
144
+ target_modules=self.lora_target_modules,
145
+ lora_dropout=self.lora_dropout,
146
+ bias="none",
147
+ task_type="CAUSAL_LM"
148
+ )
149
+ self.llama_model = get_peft_model(self.llama_model, loraconfig)
150
+
151
+ self.llama_model.print_trainable_parameters()
152
+
153
+ if self.use_grad_checkpoint_llm:
154
+ self.llama_model.gradient_checkpointing_enable()
155
+
156
+ print('Loading LLAMA Done')
157
+
158
+
159
+ if self.token_pooling:
160
+ self.llama_proj = nn.Linear(
161
+ 1408*4, self.llama_model.config.hidden_size
162
+ )
163
+ else:
164
+ self.llama_proj = nn.Linear(
165
+ 1408, self.llama_model.config.hidden_size
166
+ )
167
+ if self.prompt_path:
168
+ with open(self.prompt_path, 'r') as f:
169
+ raw_prompts = f.read().splitlines()
170
+ filted_prompts = [raw_prompt for raw_prompt in raw_prompts if "<ImageHere>" in raw_prompt]
171
+ self.prompt_list = [self.prompt_template.format(p) for p in filted_prompts]
172
+ print('Load {} training prompts'.format(len(self.prompt_list)))
173
+ print('Prompt Example \n{}'.format(random.choice(self.prompt_list)))
174
+ else:
175
+ self.prompt_list = []
176
+
177
+ def encode_img(self, image):
178
+ device = image.device
179
+ if len(image.shape) > 4:
180
+ image = image.reshape(-1, *image.shape[-3:]) # for video input flatten the batch and time dimension (4,50,3,224,224) -> (200,3,224,224)
181
+ with self.maybe_autocast():
182
+ image_embeds = self.ln_vision(self.visual_encoder(image)).to(device) # (200,3,224,224) -> (200,257,1408)
183
+ image_embeds = image_embeds[:,1:,:] # remove the first token (CLS) (200,256,1408)
184
+ bs, pn, hs = image_embeds.shape
185
+ if self.token_pooling: # concatenate every 4 adjacent tokens into one token (200,64,5632)
186
+ image_embeds = image_embeds.view(bs, int(pn/4), int(hs*4)) # (200,64,5632)
187
+
188
+ inputs_llama = self.llama_proj(image_embeds) # project to llama input size (200,64,5632) -> (200,64,4096)
189
+ atts_llama = torch.ones(inputs_llama.size()[:-1], dtype=torch.long).to(image.device)
190
+ return inputs_llama, atts_llama
191
+
192
+ def get_context_emb(self, prompt, img_list):
193
+ img_device = img_list[0].device
194
+ prompt_segs = prompt.split('<ImageHere>')
195
+ assert len(prompt_segs) == len(img_list) + 1, "Unmatched numbers of image placeholders and images."
196
+ seg_tokens = [
197
+ self.llama_tokenizer(
198
+ seg, return_tensors="pt", add_special_tokens=i==0).to(img_device).input_ids # only add bos to the first seg
199
+ for i, seg in enumerate(prompt_segs)
200
+ ]
201
+
202
+ seg_embs = [self.embed_tokens(seg_t) for seg_t in seg_tokens]
203
+
204
+ mixed_embs = [emb for pair in zip(seg_embs[:-1], img_list) for emb in pair] + [seg_embs[-1]]
205
+
206
+ mixed_embs = torch.cat(mixed_embs, dim=1)
207
+
208
+ return mixed_embs
209
+
210
+ def prompt_wrap(self, img_embeds, atts_img, prompts, lengths=None):
211
+ if prompts is None or len(prompts) == 0:
212
+ # prompts are not provided; just return the original image embedding
213
+ return img_embeds, atts_img
214
+ elif img_embeds is None:
215
+ # a prompt is provided but there is no image embedding; return the prompt embedding with right padding
216
+ self.llama_tokenizer.padding_side = "right"
217
+ prompt_tokens = self.llama_tokenizer(
218
+ prompts,
219
+ return_tensors="pt",
220
+ padding="max_length",
221
+ add_special_tokens=False
222
+ ).to(self.device)
223
+ prompt_embeds = self.embed_tokens(prompt_tokens.input_ids)
224
+ atts_prompt = prompt_tokens.attention_mask
225
+ return prompt_embeds, atts_prompt
226
+
227
+ else:
228
+ # return the multi-modal embedding in right padding
229
+ emb_lists = []
230
+ if type(prompts) == str:
231
+ prompts = [prompts] * len(img_embeds)
232
+ for idx, (each_img_embed, each_prompt) in enumerate(zip(img_embeds, prompts)):
233
+ pn = each_img_embed.shape[-2]
234
+ if lengths is not None:
235
+ each_img_embed = each_img_embed.reshape(-1, each_img_embed.shape[-1])
236
+ each_img_embed = each_img_embed[:lengths[idx] * pn]
237
+
238
+ p_segs = each_prompt.split('<ImageHere>')
239
+ interleave_emb = []
240
+ for idx, seg in enumerate(p_segs[:-1]):
241
+ p_tokens = self.llama_tokenizer(seg, return_tensors="pt", add_special_tokens=False).to(img_embeds.device)
242
+ p_embed = self.embed_tokens(p_tokens.input_ids)
243
+
244
+ interleave_emb.append(torch.cat([p_embed, each_img_embed[None][:, idx*pn:(idx+1)*pn]], dim=1))
245
+
246
+ wrapped_emb = torch.cat(interleave_emb, dim=1)
247
+ p_tokens = self.llama_tokenizer(p_segs[-1], return_tensors="pt", add_special_tokens=False).to(img_embeds.device)
248
+ p_embed = self.embed_tokens(p_tokens.input_ids)
249
+ wrapped_emb = torch.cat([wrapped_emb,p_embed], dim=1)
250
+ emb_lists.append(wrapped_emb)
251
+
252
+ emb_lens = [emb.shape[1] for emb in emb_lists]
253
+ pad_emb = self.embed_tokens(torch.tensor(self.llama_tokenizer.pad_token_id, device=img_embeds.device))
254
+
255
+ # max_length = max(emb_lens) if max(emb_lens) < self.max_context_len else self.max_context_len
256
+ max_length = self.max_context_len
257
+ wrapped_embs = pad_emb.expand(len(emb_lens), max_length, -1).clone()
258
+ wrapped_atts = torch.zeros([len(emb_lens), max_length], dtype=torch.int, device=img_embeds.device)
259
+
260
+ for i, emb in enumerate(emb_lists):
261
+ length = emb_lens[i] if emb_lens[i] < self.max_context_len else self.max_context_len
262
+ wrapped_embs[i, :length] = emb[:, :length]
263
+ wrapped_atts[i, :length] = 1
264
+
265
+ return wrapped_embs, wrapped_atts
266
+
267
+ def concat_emb_input_output(self, input_embs, input_atts, output_embs, output_atts):
268
+ """
269
+ Concatenate the batched input embedding and batched output embedding together.
270
+ Both the input and the output embedding should be right padded.
271
+ """
272
+
273
+ input_lens = []
274
+ cat_embs = []
275
+ cat_atts = []
276
+
277
+ for i in range(input_embs.size(0)):
278
+ input_len = input_atts[i].sum()
279
+ input_lens.append(input_len)
280
+
281
+ cat_embs.append(
282
+ torch.cat([
283
+ input_embs[i][:input_len],
284
+ output_embs[i],
285
+ input_embs[i][input_len:]
286
+ ])
287
+ )
288
+ cat_atts.append(
289
+ torch.cat([
290
+ input_atts[i][:input_len],
291
+ output_atts[i],
292
+ input_atts[i][input_len:]
293
+ ])
294
+ )
295
+
296
+ cat_embs = torch.stack(cat_embs)
297
+ cat_atts = torch.stack(cat_atts)
298
+ return cat_embs, cat_atts, input_lens
299
+
300
+ def get_conv_emb(self, conv_q, conv_a, conv_img):
301
+ """concatenate conversation and make sure the model is only trained to regress the answer"""
302
+
303
+ regress_embs_list = []
304
+ targets_list = []
305
+
306
+ batch_size = len(conv_q)
307
+ for batch_idx in range(batch_size):
308
+ questions, answers = conv_q[batch_idx], conv_a[batch_idx]
309
+ assigned_imgs = conv_img[batch_idx]
310
+ questions = [self.prompt_wrap(
311
+ img_embeds=img,
312
+ atts_img=None,
313
+ prompts=[q],
314
+ lengths=[img.shape[1]] if img is not None else None) for q, img in zip(questions, assigned_imgs)]
315
+ q_embs = [emb for emb, _ in questions]
316
+
317
+ answers = [self.llama_tokenizer(a, return_tensors="pt", add_special_tokens=False).to(self.device) for a in answers]
318
+ cur_emb = []
319
+ cur_target = []
320
+ for i in range(len(questions)):
321
+ cur_emb.append(q_embs[i])
322
+ cur_target.append(torch.ones_like(q_embs[i][..., 0], dtype=torch.int) * -100)
323
+
324
+ cur_emb.append(self.embed_tokens(answers[i].input_ids))
325
+ cur_target.append(answers[i].input_ids)
326
+
327
+ cur_emb = torch.cat(cur_emb, dim=1)
328
+ cur_target = torch.cat(cur_target, dim=1)
329
+
330
+ regress_embs_list.append(cur_emb)
331
+ targets_list.append(cur_target)
332
+
333
+ max_len = min(max([target.shape[1] for target in targets_list]), self.max_txt_len)
334
+
335
+ regress_embeds = torch.zeros([batch_size, max_len, cur_emb.shape[-1]], device=self.device)
336
+ regress_attn = torch.zeros([batch_size, max_len], dtype=torch.int, device=self.device)
337
+ targets = torch.ones([batch_size, max_len], dtype=torch.long, device=self.device) * -100
338
+
339
+ for batch_idx in range(batch_size):
340
+ cur_len = regress_embs_list[batch_idx].shape[1]
341
+ regress_embeds[batch_idx, :cur_len] = regress_embs_list[batch_idx][0, :max_len]
342
+ regress_attn[batch_idx, :cur_len] = 1
343
+ targets[batch_idx, :cur_len] = targets_list[batch_idx][0, :max_len]
344
+
345
+ return regress_embeds, regress_attn, targets
346
+
347
+ def preparing_embedding(self, samples):
348
+ def remove_special_tokens(data):
349
+
350
+ # if "instruction_input" in data:
351
+ data = [instruct.replace(" [caption]","") for instruct in data]
352
+ data = [instruct.replace(" [vqa]","") for instruct in data]
353
+ data = [instruct.replace(" [grounding]","") for instruct in data]
354
+ data = [instruct.replace(" [identify]","") for instruct in data]
355
+ data = [instruct.replace(" [refer]","") for instruct in data]
356
+ return data
357
+
358
+ ### prepare input tokens
359
+ if 'image' in samples:
360
+ img_embeds, img_atts = self.encode_img(samples["image"])
361
+ else:
362
+ img_embeds = img_atts = None
363
+
364
+ if 'conv_q' in samples:
365
+ # handling conversation datasets
366
+ conv_q, conv_a = samples['conv_q'], samples['conv_a']
367
+
368
+ connect_sym = samples['connect_sym'][0]
369
+ conv_q = [q.split(connect_sym)for q in conv_q]
370
+ conv_a = [a.split(connect_sym) for a in conv_a]
371
+ conv_img = assign_imgs(conv_q, img_embeds)
372
+
373
+ if self.chat_template:
374
+ conv_q = [["[INST] " + item + "[/INST]" for item in items] for items in conv_q]
375
+
376
+ regress_embeds, regress_atts, part_targets = self.get_conv_emb(conv_q, conv_a, conv_img)
377
+ cond_embeds, cond_atts = regress_embeds[:, :0], regress_atts[:, :0]
378
+
379
+ else:
380
+ if "instruction_input" in samples:
381
+ instruction = samples["instruction_input"]
382
+ elif len(self.prompt_list) > 1:
383
+ instruction = random.choice(self.prompt_list)
384
+ else:
385
+ instruction = None
386
+
387
+ if self.remove_template:
388
+ instruction = remove_special_tokens(instruction)
389
+
390
+ if self.chat_template:
391
+ instruction = ["[INST] " + instruct + "[/INST]" for instruct in instruction]
392
+
393
+ if 'length' in samples:
394
+ # the input is a sequence of images (e.g. video frames)
395
+ bsz, pn, hs = img_embeds.shape
396
+ img_embeds = img_embeds.reshape(len(samples['image']), -1, pn, hs) # (200,64,4096) -> (4,50,64,4096)
397
+ cond_embeds, cond_atts = self.prompt_wrap(img_embeds, img_atts, instruction, samples['length'])
398
+ else:
399
+ cond_embeds, cond_atts = self.prompt_wrap(img_embeds, img_atts, instruction)
400
+
401
+ ### prepare target tokens
402
+ self.llama_tokenizer.padding_side = "right"
403
+ text = [t + self.end_sym for t in samples["answer"]]
404
+
405
+ regress_tokens = self.llama_tokenizer(
406
+ text,
407
+ return_tensors="pt",
408
+ padding="max_length",
409
+ truncation=True,
410
+ max_length=self.max_txt_len,
411
+ add_special_tokens=False
412
+ ).to(self.device)
413
+
414
+ regress_token_ids = regress_tokens.input_ids
415
+ regress_atts = regress_tokens.attention_mask
416
+ part_targets = regress_token_ids.masked_fill(
417
+ regress_token_ids == self.llama_tokenizer.pad_token_id, -100
418
+ )
419
+
420
+ regress_embeds = self.embed_tokens(regress_token_ids)
421
+
422
+ return cond_embeds, cond_atts, regress_embeds, regress_atts, part_targets
423
+
424
+ def forward(self, samples, reduction="mean"):
425
+ # prepare the embedding to condition and the embedding to regress
426
+ cond_embeds, cond_atts, regress_embeds, regress_atts, part_targets = \
427
+ self.preparing_embedding(samples)
428
+
429
+ # concat the embedding to condition and the embedding to regress
430
+ inputs_embeds, attention_mask, input_lens = \
431
+ self.concat_emb_input_output(cond_embeds, cond_atts, regress_embeds, regress_atts)
432
+ # get bos token embedding
433
+ bos = torch.ones_like(part_targets[:, :1]) * self.llama_tokenizer.bos_token_id
434
+ bos_embeds = self.embed_tokens(bos)
435
+ bos_atts = attention_mask[:, :1]
436
+
437
+ # add bos token at the beginning
438
+ inputs_embeds = torch.cat([bos_embeds, inputs_embeds], dim=1)
439
+ attention_mask = torch.cat([bos_atts, attention_mask], dim=1)
440
+
441
+ targets = torch.ones([inputs_embeds.shape[0], inputs_embeds.shape[1]],
442
+ dtype=torch.long).to(self.device).fill_(-100)
443
+ for i, target in enumerate(part_targets):
444
+ targets[i, input_lens[i]+1:input_lens[i]+len(target)+1] = target # plus 1 for bos
445
+
446
+ with self.maybe_autocast():
447
+ outputs = self.llama_model(
448
+ inputs_embeds=inputs_embeds,
449
+ attention_mask=attention_mask,
450
+ return_dict=True,
451
+ labels=targets,
452
+ reduction=reduction,
453
+ use_fastv=self.use_fastv
454
+ )
455
+ loss = outputs.loss
456
+
457
+ return {"loss": loss}
458
+
459
+ @torch.no_grad()
460
+ def generate(
461
+ self,
462
+ images,
463
+ texts,
464
+ use_nucleus_sampling=False,
465
+ num_beams=1,
466
+ max_new_tokens=20,
467
+ min_length=1,
468
+ top_p=0.9,
469
+ repetition_penalty=1.5,
470
+ length_penalty=1,
471
+ temperature=1,
472
+ do_sample=False,
473
+ stop_words_ids=[2],
474
+ lengths=None,
475
+ return_video_temporal_features=False,
476
+ img_embeds=None,
477
+ ):
478
+ '''
479
+ Generation function used at inference/test time.
480
+ '''
481
+
482
+ stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(
483
+ stops=[torch.tensor([i]).to(self.device) for i in stop_words_ids])])
484
+ if img_embeds is None:
485
+ img_embeds, atts_img = self.encode_img(images.to(self.device))
486
+ else:
487
+ # Use image features passed in directly, e.g. shape (4,45,64,5632)
488
+ img_embeds = img_embeds.reshape(-1, *img_embeds.shape[-2:])
489
+ img_embeds= img_embeds.to(self.device)
490
+ img_embeds = self.llama_proj(img_embeds) # project to llama input size (200,64,5632) -> (200,64,4096)
491
+ atts_img = torch.ones(img_embeds.size()[:-1], dtype=torch.long).to(self.device)
492
+
493
+ if lengths is not None:
494
+ image_lists = []
495
+ img_embeds = img_embeds.reshape(len(lengths), -1, img_embeds.shape[-2], img_embeds.shape[-1])
496
+ for idx, img_embed in enumerate(img_embeds):
497
+ image_lists.append([img_embed[i][None] for i in range(lengths[idx])])
498
+ else:
499
+ image_lists = [[image_emb[None]] for image_emb in img_embeds]
500
+ assert len(texts) == len(image_lists)
501
+ batch_embs = [self.get_context_emb(text, img_list) for text, img_list in zip(texts, image_lists)]
502
+
503
+ batch_size = len(batch_embs)
504
+ max_len = max([emb.shape[1] for emb in batch_embs])
505
+ emb_dim = batch_embs[0].shape[2]
506
+ dtype = batch_embs[0].dtype
507
+ device = batch_embs[0].device
508
+
509
+ embs = torch.zeros([batch_size, max_len, emb_dim], dtype=dtype, device=device)
510
+ attn_mask = torch.zeros([batch_size, max_len], dtype=torch.int, device=device)
511
+ for i, emb in enumerate(batch_embs):
512
+ emb_len = emb.shape[1]
513
+ embs[i, -emb_len:] = emb[0]
514
+ attn_mask[i, -emb_len:] = 1
515
+ # check whether the input embeddings fit within the model's context window; if not, truncate to the maximum context length
516
+ if self.model_type == "Llama":
517
+ context_window = 3700
518
+ else:
519
+ context_window = 7500
520
+ if embs.shape[1] > context_window:
521
+ embs = embs[:, -context_window:]
522
+ attn_mask = attn_mask[:, -context_window:]
523
+ with self.maybe_autocast():
524
+ if return_video_temporal_features:
525
+ last_hidden_state = self.llama_model(
526
+ inputs_embeds=embs,
527
+ attention_mask=attn_mask,
528
+ output_hidden_states=True,
529
+ ).hidden_states[-1]
530
+ video_temporal_features = last_hidden_state.mean(dim=1)
531
+ # normalize the temporal features using L2 norm
532
+ # video_temporal_features = video_temporal_features / video_temporal_features.norm(dim=-1, keepdim=True)
533
+ outputs = self.llama_model.generate(
534
+ inputs_embeds=embs,
535
+ attention_mask=attn_mask,
536
+ max_new_tokens=max_new_tokens,
537
+ num_beams=num_beams,
538
+ do_sample=do_sample,
539
+ temperature=temperature,
540
+ repetition_penalty=repetition_penalty,
541
+ # stopping_criteria=stopping_criteria,
542
+ use_fastv=False,
543
+ )
544
+
545
+ answers = []
546
+ for output_token in outputs:
547
+ if output_token[0] == 0:
548
+ output_token = output_token[1:]
549
+ output_texts = self.llama_tokenizer.decode(output_token, skip_special_tokens=True)
550
+ output_texts = output_texts.split('</s>')[0] # remove the stop token </s>
551
+ output_texts = output_texts.replace("<s>", "")
552
+ output_texts = output_texts.split(r'[/INST]')[-1].strip()
553
+ answers.append(output_texts)
554
+ if return_video_temporal_features:
555
+ return answers, video_temporal_features
556
+ else:
557
+ return answers
558
+
559
+ @torch.no_grad()
560
+ def generate_text_only(
561
+ self,
562
+ images,
563
+ seg_tokens,
564
+ use_nucleus_sampling=False,
565
+ num_beams=1,
566
+ max_new_tokens=20,
567
+ min_length=1,
568
+ top_p=0.9,
569
+ repetition_penalty=1.5,
570
+ length_penalty=1,
571
+ temperature=1,
572
+ do_sample=False,
573
+ stop_words_ids=[2],
574
+ lengths=None,
575
+ return_video_temporal_features=False,
576
+ img_embeds=None,
577
+ ):
578
+ '''
579
+ Text-only generation (no image embeddings) used at inference/test time.
580
+ '''
581
+
582
+ stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(
583
+ stops=[torch.tensor([i]).to(self.device) for i in stop_words_ids])])
584
+
585
+ batch_embs = [torch.cat([self.embed_tokens(seg_t)]) for seg_t in seg_tokens]
586
+
587
+ batch_size = len(batch_embs)
588
+ max_len = max([emb.shape[1] for emb in batch_embs])
589
+ emb_dim = batch_embs[0].shape[2]
590
+ dtype = batch_embs[0].dtype
591
+ device = batch_embs[0].device
592
+
593
+ embs = torch.zeros([batch_size, max_len, emb_dim], dtype=dtype, device=device)
594
+ attn_mask = torch.zeros([batch_size, max_len], dtype=torch.int, device=device)
595
+ for i, emb in enumerate(batch_embs):
596
+ emb_len = emb.shape[1]
597
+ embs[i, -emb_len:] = emb[0]
598
+ attn_mask[i, -emb_len:] = 1
599
+
600
+ with self.maybe_autocast():
601
+ outputs = self.llama_model.generate(
602
+ inputs_embeds=embs,
603
+ attention_mask=attn_mask,
604
+ max_new_tokens=max_new_tokens,
605
+ num_beams=num_beams,
606
+ do_sample=do_sample,
607
+ temperature=temperature,
608
+ repetition_penalty=repetition_penalty,
609
+ # stopping_criteria=stopping_criteria,
610
+ )
611
+
612
+ answers = []
613
+ for output_token in outputs:
614
+ if output_token[0] == 0:
615
+ output_token = output_token[1:]
616
+ output_texts = self.llama_tokenizer.decode(output_token, skip_special_tokens=True)
617
+ output_texts = output_texts.split('</s>')[0] # remove the stop token </s>
618
+ output_texts = output_texts.replace("<s>", "")
619
+ output_texts = output_texts.split(r'[/INST]')[-1].strip()
620
+ answers.append(output_texts)
621
+ return answers
622
+
623
+
624
+
625
+ @torch.no_grad()
626
+ def multi_select(self, images, texts, answers, num_cand=None):
627
+ all_losses = []
628
+ for answer in answers:
629
+ choice_samples = {
630
+ 'image': images,
631
+ 'instruction_input': texts,
632
+ 'answer': answer
633
+ }
634
+ loss = self.forward(choice_samples, reduction='none')['loss'].reshape(-1, 1)
635
+ all_losses.append(loss)
636
+ torch.cuda.empty_cache()
637
+ all_losses = torch.cat(all_losses, dim=-1)
638
+ if num_cand is not None:
639
+ for i in range(all_losses.shape[0]):
640
+ all_losses[i, num_cand[i]:] = 9999
641
+ output_class_ranks = torch.argsort(all_losses, dim=-1)
642
+ return output_class_ranks.tolist()
643
+
644
+ def predict_answers(
645
+ self,
646
+ samples,
647
+ num_beams=5,
648
+ inference_method="generate",
649
+ max_len=10,
650
+ min_len=1,
651
+ num_ans_candidates=128,
652
+ answer_list=None,
653
+ prompt="",
654
+ length_penalty=0,
655
+ **kwargs
656
+ ):
657
+ '''
658
+ function for open-ended VQA
659
+ '''
660
+ images = samples["image"].cuda()
661
+ texts = samples["instruction_input"]
662
+
663
+ output_text = self.generate(
664
+ images=images,
665
+ texts=texts,
666
+ num_beams=num_beams,
667
+ max_new_tokens=max_len,
668
+ min_length=min_len,
669
+ length_penalty=length_penalty
670
+ )
671
+
672
+ if "apply_lemmatizer" in samples.keys() and samples["apply_lemmatizer"]:
673
+ output_text = self._lemmatize(output_text)
674
+
675
+ return output_text
676
+
677
+ def predict_class(
678
+ self,
679
+ samples,
680
+ num_beams=5,
681
+ inference_method="generate",
682
+ max_len=10,
683
+ min_len=1,
684
+ num_ans_candidates=5,
685
+ answer_list=None,
686
+ prompt="",
687
+ length_penalty=0,
688
+ **kwargs
689
+ ):
690
+ '''
691
+ function for multiple-choice VQA
692
+ '''
693
+
694
+ image = samples["image"].cuda()
695
+ instruction = samples['instruction_input']
696
+ answers = samples["choices"]
697
+ num_cand = samples["num_choices"]
698
+
699
+ ranks = self.multi_select(image, instruction, answers, num_cand)
700
+
701
+ pred_ans = []
702
+ for i, rank in enumerate(ranks):
703
+ pred = answers[rank[0]][i]
704
+ pred_ans.append(pred)
705
+ return pred_ans
706
+
707
+ def embed_tokens(self, token_ids):
708
+ try:
709
+ embeds = self.llama_model.base_model.model.model.embed_tokens(token_ids)
710
+ except AttributeError:
711
+ embeds = self.llama_model.model.embed_tokens(token_ids)
712
+
713
+ return embeds
714
+
715
+ @classmethod
716
+ def from_config(cls, cfg):
717
+ model = cls(
718
+ cfg=cfg,
719
+ )
720
+ ckpt_path = cfg.get("ckpt", "") # load weights of MiniGPT-4
721
+ if ckpt_path:
722
+ print("Load Minigpt-4-LLM Checkpoint: {}".format(ckpt_path))
723
+ ckpt = torch.load(ckpt_path, map_location="cpu")
724
+ msg = model.load_state_dict(ckpt['model'], strict=False)
725
+ # push the model to the hub with its metadata and config file
726
+ # model.push_to_hub("MiniGPT4-video-v2")
727
+ # video_config = minigpt4_video_config(cfg)
728
+ # video_config.save_pretrained("minigpt4_video_config")
729
+ # print("Save Minigpt-4-LLM Config: minigpt4_video_config")
730
+ # video_config.push_to_hub("MiniGPT4-video")
731
+ return model
732
+
733
+
734
+ def assign_imgs(batched_instruct_list, batched_img_embeds):
735
+ '''this function is used when the data is interleaved.
736
+ the interleaved data is separated, and this function assigns the
737
+ corresponding image embeddings to each segment'''
738
+ if len(batched_img_embeds.shape) == 3:
739
+ batched_img_embeds = batched_img_embeds[:, None]
740
+
741
+ batched_assigned = []
742
+
743
+ for instruct_list, img_embeds in zip(batched_instruct_list, batched_img_embeds):
744
+ img_idx = 0
745
+ assigned_img = []
746
+ n_assigned = []
747
+ for instruct in instruct_list:
748
+ n_img = instruct.count('<ImageHere>')
749
+ if n_img > 0: # this instruction includes images.
750
+ assigned_img.append(img_embeds[None, img_idx:img_idx+n_img])
751
+ img_idx += n_img
752
+ n_assigned.append(n_img)
753
+ else: # this instruction doesn't include images
754
+ assigned_img.append(None)
755
+ n_assigned.append(None)
756
+ batched_assigned.append(assigned_img)
757
+
758
+ return batched_assigned
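For reference, the tensor bookkeeping in encode_img() with token_pooling enabled can be checked in isolation. A small sketch with dummy tensors; the 257-token / 1408-dim ViT output and the 4096-dim projection follow the shapes commented in the code above, while the batch of 4 videos x 50 frames is illustrative:

import torch
import torch.nn as nn

# Dummy EVA-CLIP-g output for 4 videos x 50 frames: (B*T, 257, 1408)
vit_out = torch.randn(4 * 50, 257, 1408)

patch_tokens = vit_out[:, 1:, :]                 # drop the CLS token -> (200, 256, 1408)
bs, pn, hs = patch_tokens.shape
pooled = patch_tokens.view(bs, pn // 4, hs * 4)  # token pooling -> (200, 64, 5632)

llama_proj = nn.Linear(1408 * 4, 4096)           # matches the token_pooling branch of __init__
inputs_llama = llama_proj(pooled)                # -> (200, 64, 4096), i.e. 64 tokens per frame
atts_llama = torch.ones(inputs_llama.shape[:-1], dtype=torch.long)

print(inputs_llama.shape, atts_llama.shape)      # torch.Size([200, 64, 4096]) torch.Size([200, 64])

Each frame thus contributes 64 language-model tokens, which get_context_emb() then interleaves with the text segments around the <ImageHere> placeholders.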