Vision-CAIR commited on
Commit
26ca17a
1 Parent(s): 8c345da

Upload folder using huggingface_hub

Browse files
Qformer.py ADDED
@@ -0,0 +1,1216 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ * Copyright (c) 2023, salesforce.com, inc.
3
+ * All rights reserved.
4
+ * SPDX-License-Identifier: BSD-3-Clause
5
+ * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ * By Junnan Li
7
+ * Based on huggingface code base
8
+ * https://github.com/huggingface/transformers/blob/v4.15.0/src/transformers/models/bert
9
+ """
10
+
11
+ import math
12
+ import os
13
+ import warnings
14
+ from dataclasses import dataclass
15
+ from typing import Optional, Tuple, Dict, Any
16
+
17
+ import torch
18
+ from torch import Tensor, device, dtype, nn
19
+ import torch.utils.checkpoint
20
+ from torch import nn
21
+ from torch.nn import CrossEntropyLoss
22
+ import torch.nn.functional as F
23
+
24
+ from transformers.activations import ACT2FN
25
+ from transformers.file_utils import (
26
+ ModelOutput,
27
+ )
28
+ from transformers.modeling_outputs import (
29
+ BaseModelOutputWithPastAndCrossAttentions,
30
+ BaseModelOutputWithPoolingAndCrossAttentions,
31
+ CausalLMOutputWithCrossAttentions,
32
+ MaskedLMOutput,
33
+ MultipleChoiceModelOutput,
34
+ NextSentencePredictorOutput,
35
+ QuestionAnsweringModelOutput,
36
+ SequenceClassifierOutput,
37
+ TokenClassifierOutput,
38
+ )
39
+ from transformers.modeling_utils import (
40
+ PreTrainedModel,
41
+ apply_chunking_to_forward,
42
+ find_pruneable_heads_and_indices,
43
+ prune_linear_layer,
44
+ )
45
+ from transformers.utils import logging
46
+ from transformers.models.bert.configuration_bert import BertConfig
47
+
48
+ logger = logging.get_logger(__name__)
49
+
50
+
51
+ class BertEmbeddings(nn.Module):
52
+ """Construct the embeddings from word and position embeddings."""
53
+
54
+ def __init__(self, config):
55
+ super().__init__()
56
+ self.word_embeddings = nn.Embedding(
57
+ config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id
58
+ )
59
+ self.position_embeddings = nn.Embedding(
60
+ config.max_position_embeddings, config.hidden_size
61
+ )
62
+
63
+ # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
64
+ # any TensorFlow checkpoint file
65
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
66
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
67
+
68
+ # position_ids (1, len position emb) is contiguous in memory and exported when serialized
69
+ self.register_buffer(
70
+ "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1))
71
+ )
72
+ self.position_embedding_type = getattr(
73
+ config, "position_embedding_type", "absolute"
74
+ )
75
+
76
+ self.config = config
77
+
78
+ def forward(
79
+ self,
80
+ input_ids=None,
81
+ position_ids=None,
82
+ query_embeds=None,
83
+ past_key_values_length=0,
84
+ ):
85
+ if input_ids is not None:
86
+ seq_length = input_ids.size()[1]
87
+ else:
88
+ seq_length = 0
89
+
90
+ if position_ids is None:
91
+ position_ids = self.position_ids[
92
+ :, past_key_values_length : seq_length + past_key_values_length
93
+ ].clone()
94
+
95
+ if input_ids is not None:
96
+ embeddings = self.word_embeddings(input_ids)
97
+ if self.position_embedding_type == "absolute":
98
+ position_embeddings = self.position_embeddings(position_ids)
99
+ embeddings = embeddings + position_embeddings
100
+
101
+ if query_embeds is not None:
102
+ embeddings = torch.cat((query_embeds, embeddings), dim=1)
103
+ else:
104
+ embeddings = query_embeds
105
+
106
+ embeddings = self.LayerNorm(embeddings)
107
+ embeddings = self.dropout(embeddings)
108
+ return embeddings
109
+
110
+
111
+ class BertSelfAttention(nn.Module):
112
+ def __init__(self, config, is_cross_attention):
113
+ super().__init__()
114
+ self.config = config
115
+ if config.hidden_size % config.num_attention_heads != 0 and not hasattr(
116
+ config, "embedding_size"
117
+ ):
118
+ raise ValueError(
119
+ "The hidden size (%d) is not a multiple of the number of attention "
120
+ "heads (%d)" % (config.hidden_size, config.num_attention_heads)
121
+ )
122
+
123
+ self.num_attention_heads = config.num_attention_heads
124
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
125
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
126
+
127
+ self.query = nn.Linear(config.hidden_size, self.all_head_size)
128
+ if is_cross_attention:
129
+ self.key = nn.Linear(config.encoder_width, self.all_head_size)
130
+ self.value = nn.Linear(config.encoder_width, self.all_head_size)
131
+ else:
132
+ self.key = nn.Linear(config.hidden_size, self.all_head_size)
133
+ self.value = nn.Linear(config.hidden_size, self.all_head_size)
134
+
135
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
136
+ self.position_embedding_type = getattr(
137
+ config, "position_embedding_type", "absolute"
138
+ )
139
+ if (
140
+ self.position_embedding_type == "relative_key"
141
+ or self.position_embedding_type == "relative_key_query"
142
+ ):
143
+ self.max_position_embeddings = config.max_position_embeddings
144
+ self.distance_embedding = nn.Embedding(
145
+ 2 * config.max_position_embeddings - 1, self.attention_head_size
146
+ )
147
+ self.save_attention = False
148
+
149
+ def save_attn_gradients(self, attn_gradients):
150
+ self.attn_gradients = attn_gradients
151
+
152
+ def get_attn_gradients(self):
153
+ return self.attn_gradients
154
+
155
+ def save_attention_map(self, attention_map):
156
+ self.attention_map = attention_map
157
+
158
+ def get_attention_map(self):
159
+ return self.attention_map
160
+
161
+ def transpose_for_scores(self, x):
162
+ new_x_shape = x.size()[:-1] + (
163
+ self.num_attention_heads,
164
+ self.attention_head_size,
165
+ )
166
+ x = x.view(*new_x_shape)
167
+ return x.permute(0, 2, 1, 3)
168
+
169
+ def forward(
170
+ self,
171
+ hidden_states,
172
+ attention_mask=None,
173
+ head_mask=None,
174
+ encoder_hidden_states=None,
175
+ encoder_attention_mask=None,
176
+ past_key_value=None,
177
+ output_attentions=False,
178
+ ):
179
+
180
+ # If this is instantiated as a cross-attention module, the keys
181
+ # and values come from an encoder; the attention mask needs to be
182
+ # such that the encoder's padding tokens are not attended to.
183
+ is_cross_attention = encoder_hidden_states is not None
184
+
185
+ if is_cross_attention:
186
+ key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
187
+ value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
188
+ attention_mask = encoder_attention_mask
189
+ elif past_key_value is not None:
190
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
191
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
192
+ key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
193
+ value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
194
+ else:
195
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
196
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
197
+
198
+ mixed_query_layer = self.query(hidden_states)
199
+
200
+ query_layer = self.transpose_for_scores(mixed_query_layer)
201
+
202
+ past_key_value = (key_layer, value_layer)
203
+
204
+ # Take the dot product between "query" and "key" to get the raw attention scores.
205
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
206
+
207
+ if (
208
+ self.position_embedding_type == "relative_key"
209
+ or self.position_embedding_type == "relative_key_query"
210
+ ):
211
+ seq_length = hidden_states.size()[1]
212
+ position_ids_l = torch.arange(
213
+ seq_length, dtype=torch.long, device=hidden_states.device
214
+ ).view(-1, 1)
215
+ position_ids_r = torch.arange(
216
+ seq_length, dtype=torch.long, device=hidden_states.device
217
+ ).view(1, -1)
218
+ distance = position_ids_l - position_ids_r
219
+ positional_embedding = self.distance_embedding(
220
+ distance + self.max_position_embeddings - 1
221
+ )
222
+ positional_embedding = positional_embedding.to(
223
+ dtype=query_layer.dtype
224
+ ) # fp16 compatibility
225
+
226
+ if self.position_embedding_type == "relative_key":
227
+ relative_position_scores = torch.einsum(
228
+ "bhld,lrd->bhlr", query_layer, positional_embedding
229
+ )
230
+ attention_scores = attention_scores + relative_position_scores
231
+ elif self.position_embedding_type == "relative_key_query":
232
+ relative_position_scores_query = torch.einsum(
233
+ "bhld,lrd->bhlr", query_layer, positional_embedding
234
+ )
235
+ relative_position_scores_key = torch.einsum(
236
+ "bhrd,lrd->bhlr", key_layer, positional_embedding
237
+ )
238
+ attention_scores = (
239
+ attention_scores
240
+ + relative_position_scores_query
241
+ + relative_position_scores_key
242
+ )
243
+
244
+ attention_scores = attention_scores / math.sqrt(self.attention_head_size)
245
+ if attention_mask is not None:
246
+ # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
247
+ attention_scores = attention_scores + attention_mask
248
+
249
+ # Normalize the attention scores to probabilities.
250
+ attention_probs = nn.Softmax(dim=-1)(attention_scores)
251
+
252
+ if is_cross_attention and self.save_attention:
253
+ self.save_attention_map(attention_probs)
254
+ attention_probs.register_hook(self.save_attn_gradients)
255
+
256
+ # This is actually dropping out entire tokens to attend to, which might
257
+ # seem a bit unusual, but is taken from the original Transformer paper.
258
+ attention_probs_dropped = self.dropout(attention_probs)
259
+
260
+ # Mask heads if we want to
261
+ if head_mask is not None:
262
+ attention_probs_dropped = attention_probs_dropped * head_mask
263
+
264
+ context_layer = torch.matmul(attention_probs_dropped, value_layer)
265
+
266
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
267
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
268
+ context_layer = context_layer.view(*new_context_layer_shape)
269
+
270
+ outputs = (
271
+ (context_layer, attention_probs) if output_attentions else (context_layer,)
272
+ )
273
+
274
+ outputs = outputs + (past_key_value,)
275
+ return outputs
276
+
277
+
278
+ class BertSelfOutput(nn.Module):
279
+ def __init__(self, config):
280
+ super().__init__()
281
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
282
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
283
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
284
+
285
+ def forward(self, hidden_states, input_tensor):
286
+ hidden_states = self.dense(hidden_states)
287
+ hidden_states = self.dropout(hidden_states)
288
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
289
+ return hidden_states
290
+
291
+
292
+ class BertAttention(nn.Module):
293
+ def __init__(self, config, is_cross_attention=False):
294
+ super().__init__()
295
+ self.self = BertSelfAttention(config, is_cross_attention)
296
+ self.output = BertSelfOutput(config)
297
+ self.pruned_heads = set()
298
+
299
+ def prune_heads(self, heads):
300
+ if len(heads) == 0:
301
+ return
302
+ heads, index = find_pruneable_heads_and_indices(
303
+ heads,
304
+ self.self.num_attention_heads,
305
+ self.self.attention_head_size,
306
+ self.pruned_heads,
307
+ )
308
+
309
+ # Prune linear layers
310
+ self.self.query = prune_linear_layer(self.self.query, index)
311
+ self.self.key = prune_linear_layer(self.self.key, index)
312
+ self.self.value = prune_linear_layer(self.self.value, index)
313
+ self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
314
+
315
+ # Update hyper params and store pruned heads
316
+ self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
317
+ self.self.all_head_size = (
318
+ self.self.attention_head_size * self.self.num_attention_heads
319
+ )
320
+ self.pruned_heads = self.pruned_heads.union(heads)
321
+
322
+ def forward(
323
+ self,
324
+ hidden_states,
325
+ attention_mask=None,
326
+ head_mask=None,
327
+ encoder_hidden_states=None,
328
+ encoder_attention_mask=None,
329
+ past_key_value=None,
330
+ output_attentions=False,
331
+ ):
332
+ self_outputs = self.self(
333
+ hidden_states,
334
+ attention_mask,
335
+ head_mask,
336
+ encoder_hidden_states,
337
+ encoder_attention_mask,
338
+ past_key_value,
339
+ output_attentions,
340
+ )
341
+ attention_output = self.output(self_outputs[0], hidden_states)
342
+
343
+ outputs = (attention_output,) + self_outputs[
344
+ 1:
345
+ ] # add attentions if we output them
346
+ return outputs
347
+
348
+
349
+ class BertIntermediate(nn.Module):
350
+ def __init__(self, config):
351
+ super().__init__()
352
+ self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
353
+ if isinstance(config.hidden_act, str):
354
+ self.intermediate_act_fn = ACT2FN[config.hidden_act]
355
+ else:
356
+ self.intermediate_act_fn = config.hidden_act
357
+
358
+ def forward(self, hidden_states):
359
+ hidden_states = self.dense(hidden_states)
360
+ hidden_states = self.intermediate_act_fn(hidden_states)
361
+ return hidden_states
362
+
363
+
364
+ class BertOutput(nn.Module):
365
+ def __init__(self, config):
366
+ super().__init__()
367
+ self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
368
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
369
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
370
+
371
+ def forward(self, hidden_states, input_tensor):
372
+ hidden_states = self.dense(hidden_states)
373
+ hidden_states = self.dropout(hidden_states)
374
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
375
+ return hidden_states
376
+
377
+
378
+ class BertLayer(nn.Module):
379
+ def __init__(self, config, layer_num):
380
+ super().__init__()
381
+ self.config = config
382
+ self.chunk_size_feed_forward = config.chunk_size_feed_forward
383
+ self.seq_len_dim = 1
384
+ self.attention = BertAttention(config)
385
+ self.layer_num = layer_num
386
+ if (
387
+ self.config.add_cross_attention
388
+ and layer_num % self.config.cross_attention_freq == 0
389
+ ):
390
+ self.crossattention = BertAttention(
391
+ config, is_cross_attention=self.config.add_cross_attention
392
+ )
393
+ self.has_cross_attention = True
394
+ else:
395
+ self.has_cross_attention = False
396
+ self.intermediate = BertIntermediate(config)
397
+ self.output = BertOutput(config)
398
+
399
+ self.intermediate_query = BertIntermediate(config)
400
+ self.output_query = BertOutput(config)
401
+
402
+ def forward(
403
+ self,
404
+ hidden_states,
405
+ attention_mask=None,
406
+ head_mask=None,
407
+ encoder_hidden_states=None,
408
+ encoder_attention_mask=None,
409
+ past_key_value=None,
410
+ output_attentions=False,
411
+ query_length=0,
412
+ ):
413
+ # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
414
+ self_attn_past_key_value = (
415
+ past_key_value[:2] if past_key_value is not None else None
416
+ )
417
+ self_attention_outputs = self.attention(
418
+ hidden_states,
419
+ attention_mask,
420
+ head_mask,
421
+ output_attentions=output_attentions,
422
+ past_key_value=self_attn_past_key_value,
423
+ )
424
+ attention_output = self_attention_outputs[0]
425
+ outputs = self_attention_outputs[1:-1]
426
+
427
+ present_key_value = self_attention_outputs[-1]
428
+
429
+ if query_length > 0:
430
+ query_attention_output = attention_output[:, :query_length, :]
431
+
432
+ if self.has_cross_attention:
433
+ assert (
434
+ encoder_hidden_states is not None
435
+ ), "encoder_hidden_states must be given for cross-attention layers"
436
+ cross_attention_outputs = self.crossattention(
437
+ query_attention_output,
438
+ attention_mask,
439
+ head_mask,
440
+ encoder_hidden_states,
441
+ encoder_attention_mask,
442
+ output_attentions=output_attentions,
443
+ )
444
+ query_attention_output = cross_attention_outputs[0]
445
+ outputs = (
446
+ outputs + cross_attention_outputs[1:-1]
447
+ ) # add cross attentions if we output attention weights
448
+
449
+ layer_output = apply_chunking_to_forward(
450
+ self.feed_forward_chunk_query,
451
+ self.chunk_size_feed_forward,
452
+ self.seq_len_dim,
453
+ query_attention_output,
454
+ )
455
+ if attention_output.shape[1] > query_length:
456
+ layer_output_text = apply_chunking_to_forward(
457
+ self.feed_forward_chunk,
458
+ self.chunk_size_feed_forward,
459
+ self.seq_len_dim,
460
+ attention_output[:, query_length:, :],
461
+ )
462
+ layer_output = torch.cat([layer_output, layer_output_text], dim=1)
463
+ else:
464
+ layer_output = apply_chunking_to_forward(
465
+ self.feed_forward_chunk,
466
+ self.chunk_size_feed_forward,
467
+ self.seq_len_dim,
468
+ attention_output,
469
+ )
470
+ outputs = (layer_output,) + outputs
471
+
472
+ outputs = outputs + (present_key_value,)
473
+
474
+ return outputs
475
+
476
+ def feed_forward_chunk(self, attention_output):
477
+ intermediate_output = self.intermediate(attention_output)
478
+ layer_output = self.output(intermediate_output, attention_output)
479
+ return layer_output
480
+
481
+ def feed_forward_chunk_query(self, attention_output):
482
+ intermediate_output = self.intermediate_query(attention_output)
483
+ layer_output = self.output_query(intermediate_output, attention_output)
484
+ return layer_output
485
+
486
+
487
+ class BertEncoder(nn.Module):
488
+ def __init__(self, config):
489
+ super().__init__()
490
+ self.config = config
491
+ self.layer = nn.ModuleList(
492
+ [BertLayer(config, i) for i in range(config.num_hidden_layers)]
493
+ )
494
+
495
+ def forward(
496
+ self,
497
+ hidden_states,
498
+ attention_mask=None,
499
+ head_mask=None,
500
+ encoder_hidden_states=None,
501
+ encoder_attention_mask=None,
502
+ past_key_values=None,
503
+ use_cache=None,
504
+ output_attentions=False,
505
+ output_hidden_states=False,
506
+ return_dict=True,
507
+ query_length=0,
508
+ ):
509
+ all_hidden_states = () if output_hidden_states else None
510
+ all_self_attentions = () if output_attentions else None
511
+ all_cross_attentions = (
512
+ () if output_attentions and self.config.add_cross_attention else None
513
+ )
514
+
515
+ next_decoder_cache = () if use_cache else None
516
+
517
+ for i in range(self.config.num_hidden_layers):
518
+ layer_module = self.layer[i]
519
+ if output_hidden_states:
520
+ all_hidden_states = all_hidden_states + (hidden_states,)
521
+
522
+ layer_head_mask = head_mask[i] if head_mask is not None else None
523
+ past_key_value = past_key_values[i] if past_key_values is not None else None
524
+
525
+ if getattr(self.config, "gradient_checkpointing", False) and self.training:
526
+
527
+ if use_cache:
528
+ logger.warn(
529
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
530
+ )
531
+ use_cache = False
532
+
533
+ def create_custom_forward(module):
534
+ def custom_forward(*inputs):
535
+ return module(
536
+ *inputs, past_key_value, output_attentions, query_length
537
+ )
538
+
539
+ return custom_forward
540
+
541
+ layer_outputs = torch.utils.checkpoint.checkpoint(
542
+ create_custom_forward(layer_module),
543
+ hidden_states,
544
+ attention_mask,
545
+ layer_head_mask,
546
+ encoder_hidden_states,
547
+ encoder_attention_mask,
548
+ )
549
+ else:
550
+ layer_outputs = layer_module(
551
+ hidden_states,
552
+ attention_mask,
553
+ layer_head_mask,
554
+ encoder_hidden_states,
555
+ encoder_attention_mask,
556
+ past_key_value,
557
+ output_attentions,
558
+ query_length,
559
+ )
560
+
561
+ hidden_states = layer_outputs[0]
562
+ if use_cache:
563
+ next_decoder_cache += (layer_outputs[-1],)
564
+ if output_attentions:
565
+ all_self_attentions = all_self_attentions + (layer_outputs[1],)
566
+ all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
567
+
568
+ if output_hidden_states:
569
+ all_hidden_states = all_hidden_states + (hidden_states,)
570
+
571
+ if not return_dict:
572
+ return tuple(
573
+ v
574
+ for v in [
575
+ hidden_states,
576
+ next_decoder_cache,
577
+ all_hidden_states,
578
+ all_self_attentions,
579
+ all_cross_attentions,
580
+ ]
581
+ if v is not None
582
+ )
583
+ return BaseModelOutputWithPastAndCrossAttentions(
584
+ last_hidden_state=hidden_states,
585
+ past_key_values=next_decoder_cache,
586
+ hidden_states=all_hidden_states,
587
+ attentions=all_self_attentions,
588
+ cross_attentions=all_cross_attentions,
589
+ )
590
+
591
+
592
+ class BertPooler(nn.Module):
593
+ def __init__(self, config):
594
+ super().__init__()
595
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
596
+ self.activation = nn.Tanh()
597
+
598
+ def forward(self, hidden_states):
599
+ # We "pool" the model by simply taking the hidden state corresponding
600
+ # to the first token.
601
+ first_token_tensor = hidden_states[:, 0]
602
+ pooled_output = self.dense(first_token_tensor)
603
+ pooled_output = self.activation(pooled_output)
604
+ return pooled_output
605
+
606
+
607
+ class BertPredictionHeadTransform(nn.Module):
608
+ def __init__(self, config):
609
+ super().__init__()
610
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
611
+ if isinstance(config.hidden_act, str):
612
+ self.transform_act_fn = ACT2FN[config.hidden_act]
613
+ else:
614
+ self.transform_act_fn = config.hidden_act
615
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
616
+
617
+ def forward(self, hidden_states):
618
+ hidden_states = self.dense(hidden_states)
619
+ hidden_states = self.transform_act_fn(hidden_states)
620
+ hidden_states = self.LayerNorm(hidden_states)
621
+ return hidden_states
622
+
623
+
624
+ class BertLMPredictionHead(nn.Module):
625
+ def __init__(self, config):
626
+ super().__init__()
627
+ self.transform = BertPredictionHeadTransform(config)
628
+
629
+ # The output weights are the same as the input embeddings, but there is
630
+ # an output-only bias for each token.
631
+ self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
632
+
633
+ self.bias = nn.Parameter(torch.zeros(config.vocab_size))
634
+
635
+ # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
636
+ self.decoder.bias = self.bias
637
+
638
+ def forward(self, hidden_states):
639
+ hidden_states = self.transform(hidden_states)
640
+ hidden_states = self.decoder(hidden_states)
641
+ return hidden_states
642
+
643
+
644
+ class BertOnlyMLMHead(nn.Module):
645
+ def __init__(self, config):
646
+ super().__init__()
647
+ self.predictions = BertLMPredictionHead(config)
648
+
649
+ def forward(self, sequence_output):
650
+ prediction_scores = self.predictions(sequence_output)
651
+ return prediction_scores
652
+
653
+
654
+ class BertPreTrainedModel(PreTrainedModel):
655
+ """
656
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
657
+ models.
658
+ """
659
+
660
+ config_class = BertConfig
661
+ base_model_prefix = "bert"
662
+ _keys_to_ignore_on_load_missing = [r"position_ids"]
663
+
664
+ def _init_weights(self, module):
665
+ """Initialize the weights"""
666
+ if isinstance(module, (nn.Linear, nn.Embedding)):
667
+ # Slightly different from the TF version which uses truncated_normal for initialization
668
+ # cf https://github.com/pytorch/pytorch/pull/5617
669
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
670
+ elif isinstance(module, nn.LayerNorm):
671
+ module.bias.data.zero_()
672
+ module.weight.data.fill_(1.0)
673
+ if isinstance(module, nn.Linear) and module.bias is not None:
674
+ module.bias.data.zero_()
675
+
676
+
677
+ class BertModel(BertPreTrainedModel):
678
+ """
679
+ The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
680
+ cross-attention is added between the self-attention layers, following the architecture described in `Attention is
681
+ all you need <https://arxiv.org/abs/1706.03762>`__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
682
+ Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
683
+ argument and :obj:`add_cross_attention` set to :obj:`True`; an :obj:`encoder_hidden_states` is then expected as an
684
+ input to the forward pass.
685
+ """
686
+
687
+ def __init__(self, config, add_pooling_layer=False):
688
+ super().__init__(config)
689
+ self.config = config
690
+
691
+ self.embeddings = BertEmbeddings(config)
692
+
693
+ self.encoder = BertEncoder(config)
694
+
695
+ self.pooler = BertPooler(config) if add_pooling_layer else None
696
+
697
+ self.init_weights()
698
+
699
+ def get_input_embeddings(self):
700
+ return self.embeddings.word_embeddings
701
+
702
+ def set_input_embeddings(self, value):
703
+ self.embeddings.word_embeddings = value
704
+
705
+ def _prune_heads(self, heads_to_prune):
706
+ """
707
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
708
+ class PreTrainedModel
709
+ """
710
+ for layer, heads in heads_to_prune.items():
711
+ self.encoder.layer[layer].attention.prune_heads(heads)
712
+
713
+ def get_extended_attention_mask(
714
+ self,
715
+ attention_mask: Tensor,
716
+ input_shape: Tuple[int],
717
+ device: device,
718
+ is_decoder: bool,
719
+ has_query: bool = False,
720
+ ) -> Tensor:
721
+ """
722
+ Makes broadcastable attention and causal masks so that future and masked tokens are ignored.
723
+
724
+ Arguments:
725
+ attention_mask (:obj:`torch.Tensor`):
726
+ Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
727
+ input_shape (:obj:`Tuple[int]`):
728
+ The shape of the input to the model.
729
+ device: (:obj:`torch.device`):
730
+ The device of the input to the model.
731
+
732
+ Returns:
733
+ :obj:`torch.Tensor` The extended attention mask, with a the same dtype as :obj:`attention_mask.dtype`.
734
+ """
735
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
736
+ # ourselves in which case we just need to make it broadcastable to all heads.
737
+ if attention_mask.dim() == 3:
738
+ extended_attention_mask = attention_mask[:, None, :, :]
739
+ elif attention_mask.dim() == 2:
740
+ # Provided a padding mask of dimensions [batch_size, seq_length]
741
+ # - if the model is a decoder, apply a causal mask in addition to the padding mask
742
+ # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
743
+ if is_decoder:
744
+ batch_size, seq_length = input_shape
745
+
746
+ seq_ids = torch.arange(seq_length, device=device)
747
+ causal_mask = (
748
+ seq_ids[None, None, :].repeat(batch_size, seq_length, 1)
749
+ <= seq_ids[None, :, None]
750
+ )
751
+
752
+ # add a prefix ones mask to the causal mask
753
+ # causal and attention masks must have same type with pytorch version < 1.3
754
+ causal_mask = causal_mask.to(attention_mask.dtype)
755
+
756
+ if causal_mask.shape[1] < attention_mask.shape[1]:
757
+ prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1]
758
+ if has_query: # UniLM style attention mask
759
+ causal_mask = torch.cat(
760
+ [
761
+ torch.zeros(
762
+ (batch_size, prefix_seq_len, seq_length),
763
+ device=device,
764
+ dtype=causal_mask.dtype,
765
+ ),
766
+ causal_mask,
767
+ ],
768
+ axis=1,
769
+ )
770
+ causal_mask = torch.cat(
771
+ [
772
+ torch.ones(
773
+ (batch_size, causal_mask.shape[1], prefix_seq_len),
774
+ device=device,
775
+ dtype=causal_mask.dtype,
776
+ ),
777
+ causal_mask,
778
+ ],
779
+ axis=-1,
780
+ )
781
+ extended_attention_mask = (
782
+ causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
783
+ )
784
+ else:
785
+ extended_attention_mask = attention_mask[:, None, None, :]
786
+ else:
787
+ raise ValueError(
788
+ "Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format(
789
+ input_shape, attention_mask.shape
790
+ )
791
+ )
792
+
793
+ # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
794
+ # masked positions, this operation will create a tensor which is 0.0 for
795
+ # positions we want to attend and -10000.0 for masked positions.
796
+ # Since we are adding it to the raw scores before the softmax, this is
797
+ # effectively the same as removing these entirely.
798
+ extended_attention_mask = extended_attention_mask.to(
799
+ dtype=self.dtype
800
+ ) # fp16 compatibility
801
+ extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
802
+ return extended_attention_mask
803
+
804
+ def forward(
805
+ self,
806
+ input_ids=None,
807
+ attention_mask=None,
808
+ position_ids=None,
809
+ head_mask=None,
810
+ query_embeds=None,
811
+ encoder_hidden_states=None,
812
+ encoder_attention_mask=None,
813
+ past_key_values=None,
814
+ use_cache=None,
815
+ output_attentions=None,
816
+ output_hidden_states=None,
817
+ return_dict=None,
818
+ is_decoder=False,
819
+ ):
820
+ r"""
821
+ encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
822
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
823
+ the model is configured as a decoder.
824
+ encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
825
+ Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
826
+ the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
827
+ - 1 for tokens that are **not masked**,
828
+ - 0 for tokens that are **masked**.
829
+ past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
830
+ Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
831
+ If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
832
+ (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
833
+ instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
834
+ use_cache (:obj:`bool`, `optional`):
835
+ If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
836
+ decoding (see :obj:`past_key_values`).
837
+ """
838
+ output_attentions = (
839
+ output_attentions
840
+ if output_attentions is not None
841
+ else self.config.output_attentions
842
+ )
843
+ output_hidden_states = (
844
+ output_hidden_states
845
+ if output_hidden_states is not None
846
+ else self.config.output_hidden_states
847
+ )
848
+ return_dict = (
849
+ return_dict if return_dict is not None else self.config.use_return_dict
850
+ )
851
+
852
+ # use_cache = use_cache if use_cache is not None else self.config.use_cache
853
+
854
+ if input_ids is None:
855
+ assert (
856
+ query_embeds is not None
857
+ ), "You have to specify query_embeds when input_ids is None"
858
+
859
+ # past_key_values_length
860
+ past_key_values_length = (
861
+ past_key_values[0][0].shape[2] - self.config.query_length
862
+ if past_key_values is not None
863
+ else 0
864
+ )
865
+
866
+ query_length = query_embeds.shape[1] if query_embeds is not None else 0
867
+
868
+ embedding_output = self.embeddings(
869
+ input_ids=input_ids,
870
+ position_ids=position_ids,
871
+ query_embeds=query_embeds,
872
+ past_key_values_length=past_key_values_length,
873
+ )
874
+
875
+ input_shape = embedding_output.size()[:-1]
876
+ batch_size, seq_length = input_shape
877
+ device = embedding_output.device
878
+
879
+ if attention_mask is None:
880
+ attention_mask = torch.ones(
881
+ ((batch_size, seq_length + past_key_values_length)), device=device
882
+ )
883
+
884
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
885
+ # ourselves in which case we just need to make it broadcastable to all heads.
886
+ if is_decoder:
887
+ extended_attention_mask = self.get_extended_attention_mask(
888
+ attention_mask,
889
+ input_ids.shape,
890
+ device,
891
+ is_decoder,
892
+ has_query=(query_embeds is not None),
893
+ )
894
+ else:
895
+ extended_attention_mask = self.get_extended_attention_mask(
896
+ attention_mask, input_shape, device, is_decoder
897
+ )
898
+
899
+ # If a 2D or 3D attention mask is provided for the cross-attention
900
+ # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
901
+ if encoder_hidden_states is not None:
902
+ if type(encoder_hidden_states) == list:
903
+ encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[
904
+ 0
905
+ ].size()
906
+ else:
907
+ (
908
+ encoder_batch_size,
909
+ encoder_sequence_length,
910
+ _,
911
+ ) = encoder_hidden_states.size()
912
+ encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
913
+
914
+ if type(encoder_attention_mask) == list:
915
+ encoder_extended_attention_mask = [
916
+ self.invert_attention_mask(mask) for mask in encoder_attention_mask
917
+ ]
918
+ elif encoder_attention_mask is None:
919
+ encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
920
+ encoder_extended_attention_mask = self.invert_attention_mask(
921
+ encoder_attention_mask
922
+ )
923
+ else:
924
+ encoder_extended_attention_mask = self.invert_attention_mask(
925
+ encoder_attention_mask
926
+ )
927
+ else:
928
+ encoder_extended_attention_mask = None
929
+
930
+ # Prepare head mask if needed
931
+ # 1.0 in head_mask indicate we keep the head
932
+ # attention_probs has shape bsz x n_heads x N x N
933
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
934
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
935
+ head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
936
+
937
+ encoder_outputs = self.encoder(
938
+ embedding_output,
939
+ attention_mask=extended_attention_mask,
940
+ head_mask=head_mask,
941
+ encoder_hidden_states=encoder_hidden_states,
942
+ encoder_attention_mask=encoder_extended_attention_mask,
943
+ past_key_values=past_key_values,
944
+ use_cache=use_cache,
945
+ output_attentions=output_attentions,
946
+ output_hidden_states=output_hidden_states,
947
+ return_dict=return_dict,
948
+ query_length=query_length,
949
+ )
950
+ sequence_output = encoder_outputs[0]
951
+ pooled_output = (
952
+ self.pooler(sequence_output) if self.pooler is not None else None
953
+ )
954
+
955
+ if not return_dict:
956
+ return (sequence_output, pooled_output) + encoder_outputs[1:]
957
+
958
+ return BaseModelOutputWithPoolingAndCrossAttentions(
959
+ last_hidden_state=sequence_output,
960
+ pooler_output=pooled_output,
961
+ past_key_values=encoder_outputs.past_key_values,
962
+ hidden_states=encoder_outputs.hidden_states,
963
+ attentions=encoder_outputs.attentions,
964
+ cross_attentions=encoder_outputs.cross_attentions,
965
+ )
966
+
967
+
968
+ class BertLMHeadModel(BertPreTrainedModel):
969
+
970
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
971
+ _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
972
+
973
+ def __init__(self, config):
974
+ super().__init__(config)
975
+
976
+ self.bert = BertModel(config, add_pooling_layer=False)
977
+ self.cls = BertOnlyMLMHead(config)
978
+
979
+ self.init_weights()
980
+
981
+ def get_output_embeddings(self):
982
+ return self.cls.predictions.decoder
983
+
984
+ def set_output_embeddings(self, new_embeddings):
985
+ self.cls.predictions.decoder = new_embeddings
986
+
987
+ def forward(
988
+ self,
989
+ input_ids=None,
990
+ attention_mask=None,
991
+ position_ids=None,
992
+ head_mask=None,
993
+ query_embeds=None,
994
+ encoder_hidden_states=None,
995
+ encoder_attention_mask=None,
996
+ labels=None,
997
+ past_key_values=None,
998
+ use_cache=True,
999
+ output_attentions=None,
1000
+ output_hidden_states=None,
1001
+ return_dict=None,
1002
+ return_logits=False,
1003
+ is_decoder=True,
1004
+ reduction="mean",
1005
+ ):
1006
+ r"""
1007
+ encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
1008
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
1009
+ the model is configured as a decoder.
1010
+ encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
1011
+ Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
1012
+ the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
1013
+ - 1 for tokens that are **not masked**,
1014
+ - 0 for tokens that are **masked**.
1015
+ labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
1016
+ Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
1017
+ ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are
1018
+ ignored (masked), the loss is only computed for the tokens with labels n ``[0, ..., config.vocab_size]``
1019
+ past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
1020
+ Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
1021
+ If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
1022
+ (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
1023
+ instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
1024
+ use_cache (:obj:`bool`, `optional`):
1025
+ If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
1026
+ decoding (see :obj:`past_key_values`).
1027
+ Returns:
1028
+ Example::
1029
+ >>> from transformers import BertTokenizer, BertLMHeadModel, BertConfig
1030
+ >>> import torch
1031
+ >>> tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
1032
+ >>> config = BertConfig.from_pretrained("bert-base-cased")
1033
+ >>> model = BertLMHeadModel.from_pretrained('bert-base-cased', config=config)
1034
+ >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
1035
+ >>> outputs = model(**inputs)
1036
+ >>> prediction_logits = outputs.logits
1037
+ """
1038
+ return_dict = (
1039
+ return_dict if return_dict is not None else self.config.use_return_dict
1040
+ )
1041
+ if labels is not None:
1042
+ use_cache = False
1043
+ if past_key_values is not None:
1044
+ query_embeds = None
1045
+
1046
+ outputs = self.bert(
1047
+ input_ids,
1048
+ attention_mask=attention_mask,
1049
+ position_ids=position_ids,
1050
+ head_mask=head_mask,
1051
+ query_embeds=query_embeds,
1052
+ encoder_hidden_states=encoder_hidden_states,
1053
+ encoder_attention_mask=encoder_attention_mask,
1054
+ past_key_values=past_key_values,
1055
+ use_cache=use_cache,
1056
+ output_attentions=output_attentions,
1057
+ output_hidden_states=output_hidden_states,
1058
+ return_dict=return_dict,
1059
+ is_decoder=is_decoder,
1060
+ )
1061
+
1062
+ sequence_output = outputs[0]
1063
+ if query_embeds is not None:
1064
+ sequence_output = outputs[0][:, query_embeds.shape[1] :, :]
1065
+
1066
+ prediction_scores = self.cls(sequence_output)
1067
+
1068
+ if return_logits:
1069
+ return prediction_scores[:, :-1, :].contiguous()
1070
+
1071
+ lm_loss = None
1072
+ if labels is not None:
1073
+ # we are doing next-token prediction; shift prediction scores and input ids by one
1074
+ shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous()
1075
+ labels = labels[:, 1:].contiguous()
1076
+ loss_fct = CrossEntropyLoss(reduction=reduction, label_smoothing=0.1)
1077
+ lm_loss = loss_fct(
1078
+ shifted_prediction_scores.view(-1, self.config.vocab_size),
1079
+ labels.view(-1),
1080
+ )
1081
+ if reduction == "none":
1082
+ lm_loss = lm_loss.view(prediction_scores.size(0), -1).sum(1)
1083
+
1084
+ if not return_dict:
1085
+ output = (prediction_scores,) + outputs[2:]
1086
+ return ((lm_loss,) + output) if lm_loss is not None else output
1087
+
1088
+ return CausalLMOutputWithCrossAttentions(
1089
+ loss=lm_loss,
1090
+ logits=prediction_scores,
1091
+ past_key_values=outputs.past_key_values,
1092
+ hidden_states=outputs.hidden_states,
1093
+ attentions=outputs.attentions,
1094
+ cross_attentions=outputs.cross_attentions,
1095
+ )
1096
+
1097
+ def prepare_inputs_for_generation(
1098
+ self, input_ids, query_embeds, past=None, attention_mask=None, **model_kwargs
1099
+ ):
1100
+ # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
1101
+ if attention_mask is None:
1102
+ attention_mask = input_ids.new_ones(input_ids.shape)
1103
+ query_mask = input_ids.new_ones(query_embeds.shape[:-1])
1104
+ attention_mask = torch.cat([query_mask, attention_mask], dim=-1)
1105
+
1106
+ # cut decoder_input_ids if past is used
1107
+ if past is not None:
1108
+ input_ids = input_ids[:, -1:]
1109
+
1110
+ return {
1111
+ "input_ids": input_ids,
1112
+ "query_embeds": query_embeds,
1113
+ "attention_mask": attention_mask,
1114
+ "past_key_values": past,
1115
+ "encoder_hidden_states": model_kwargs.get("encoder_hidden_states", None),
1116
+ "encoder_attention_mask": model_kwargs.get("encoder_attention_mask", None),
1117
+ "is_decoder": True,
1118
+ }
1119
+
1120
+ def _reorder_cache(self, past, beam_idx):
1121
+ reordered_past = ()
1122
+ for layer_past in past:
1123
+ reordered_past += (
1124
+ tuple(
1125
+ past_state.index_select(0, beam_idx) for past_state in layer_past
1126
+ ),
1127
+ )
1128
+ return reordered_past
1129
+
1130
+
1131
+ class BertForMaskedLM(BertPreTrainedModel):
1132
+
1133
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
1134
+ _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
1135
+
1136
+ def __init__(self, config):
1137
+ super().__init__(config)
1138
+
1139
+ self.bert = BertModel(config, add_pooling_layer=False)
1140
+ self.cls = BertOnlyMLMHead(config)
1141
+
1142
+ self.init_weights()
1143
+
1144
+ def get_output_embeddings(self):
1145
+ return self.cls.predictions.decoder
1146
+
1147
+ def set_output_embeddings(self, new_embeddings):
1148
+ self.cls.predictions.decoder = new_embeddings
1149
+
1150
+ def forward(
1151
+ self,
1152
+ input_ids=None,
1153
+ attention_mask=None,
1154
+ position_ids=None,
1155
+ head_mask=None,
1156
+ query_embeds=None,
1157
+ encoder_hidden_states=None,
1158
+ encoder_attention_mask=None,
1159
+ labels=None,
1160
+ output_attentions=None,
1161
+ output_hidden_states=None,
1162
+ return_dict=None,
1163
+ return_logits=False,
1164
+ is_decoder=False,
1165
+ ):
1166
+ r"""
1167
+ labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
1168
+ Labels for computing the masked language modeling loss. Indices should be in ``[-100, 0, ...,
1169
+ config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored
1170
+ (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``
1171
+ """
1172
+
1173
+ return_dict = (
1174
+ return_dict if return_dict is not None else self.config.use_return_dict
1175
+ )
1176
+
1177
+ outputs = self.bert(
1178
+ input_ids,
1179
+ attention_mask=attention_mask,
1180
+ position_ids=position_ids,
1181
+ head_mask=head_mask,
1182
+ query_embeds=query_embeds,
1183
+ encoder_hidden_states=encoder_hidden_states,
1184
+ encoder_attention_mask=encoder_attention_mask,
1185
+ output_attentions=output_attentions,
1186
+ output_hidden_states=output_hidden_states,
1187
+ return_dict=return_dict,
1188
+ is_decoder=is_decoder,
1189
+ )
1190
+
1191
+ if query_embeds is not None:
1192
+ sequence_output = outputs[0][:, query_embeds.shape[1] :, :]
1193
+ prediction_scores = self.cls(sequence_output)
1194
+
1195
+ if return_logits:
1196
+ return prediction_scores
1197
+
1198
+ masked_lm_loss = None
1199
+ if labels is not None:
1200
+ loss_fct = CrossEntropyLoss() # -100 index = padding token
1201
+ masked_lm_loss = loss_fct(
1202
+ prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)
1203
+ )
1204
+
1205
+ if not return_dict:
1206
+ output = (prediction_scores,) + outputs[2:]
1207
+ return (
1208
+ ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
1209
+ )
1210
+
1211
+ return MaskedLMOutput(
1212
+ loss=masked_lm_loss,
1213
+ logits=prediction_scores,
1214
+ hidden_states=outputs.hidden_states,
1215
+ attentions=outputs.attentions,
1216
+ )
__init__.py ADDED
@@ -0,0 +1,200 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Copyright (c) 2022, salesforce.com, inc.
3
+ All rights reserved.
4
+ SPDX-License-Identifier: BSD-3-Clause
5
+ For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+
8
+ import logging
9
+ import torch
10
+ from omegaconf import OmegaConf
11
+
12
+ from minigpt4_video.registry import registry
13
+ from minigpt4_video.base_model import BaseModel
14
+ from minigpt4_video.blip2 import Blip2Base
15
+ from minigpt4_video.base_processor import BaseProcessor
16
+ from minigpt4_video.mini_gpt4_llama_v2 import MiniGPT4_Video
17
+
18
+
19
+ __all__ = [
20
+ "load_model",
21
+ "BaseModel",
22
+ "Blip2Base",
23
+ "MiniGPT4_Video",
24
+ ]
25
+
26
+
27
+ def load_model(name, model_type, is_eval=False, device="cpu", checkpoint=None):
28
+ """
29
+ Load supported models.
30
+
31
+ To list all available models and types in registry:
32
+ >>> from minigpt4.models import model_zoo
33
+ >>> print(model_zoo)
34
+
35
+ Args:
36
+ name (str): name of the model.
37
+ model_type (str): type of the model.
38
+ is_eval (bool): whether the model is in eval mode. Default: False.
39
+ device (str): device to use. Default: "cpu".
40
+ checkpoint (str): path or to checkpoint. Default: None.
41
+ Note that expecting the checkpoint to have the same keys in state_dict as the model.
42
+
43
+ Returns:
44
+ model (torch.nn.Module): model.
45
+ """
46
+
47
+ model = registry.get_model_class(name).from_pretrained(model_type=model_type)
48
+
49
+ if checkpoint is not None:
50
+ model.load_checkpoint(checkpoint)
51
+
52
+ if is_eval:
53
+ model.eval()
54
+
55
+ if device == "cpu":
56
+ model = model.float()
57
+
58
+ return model.to(device)
59
+
60
+
61
+ def load_preprocess(config):
62
+ """
63
+ Load preprocessor configs and construct preprocessors.
64
+
65
+ If no preprocessor is specified, return BaseProcessor, which does not do any preprocessing.
66
+
67
+ Args:
68
+ config (dict): preprocessor configs.
69
+
70
+ Returns:
71
+ vis_processors (dict): preprocessors for visual inputs.
72
+ txt_processors (dict): preprocessors for text inputs.
73
+
74
+ Key is "train" or "eval" for processors used in training and evaluation respectively.
75
+ """
76
+
77
+ def _build_proc_from_cfg(cfg):
78
+ return (
79
+ registry.get_processor_class(cfg.name).from_config(cfg)
80
+ if cfg is not None
81
+ else BaseProcessor()
82
+ )
83
+
84
+ vis_processors = dict()
85
+ txt_processors = dict()
86
+
87
+ vis_proc_cfg = config.get("vis_processor")
88
+ txt_proc_cfg = config.get("text_processor")
89
+
90
+ if vis_proc_cfg is not None:
91
+ vis_train_cfg = vis_proc_cfg.get("train")
92
+ vis_eval_cfg = vis_proc_cfg.get("eval")
93
+ else:
94
+ vis_train_cfg = None
95
+ vis_eval_cfg = None
96
+
97
+ vis_processors["train"] = _build_proc_from_cfg(vis_train_cfg)
98
+ vis_processors["eval"] = _build_proc_from_cfg(vis_eval_cfg)
99
+
100
+ if txt_proc_cfg is not None:
101
+ txt_train_cfg = txt_proc_cfg.get("train")
102
+ txt_eval_cfg = txt_proc_cfg.get("eval")
103
+ else:
104
+ txt_train_cfg = None
105
+ txt_eval_cfg = None
106
+
107
+ txt_processors["train"] = _build_proc_from_cfg(txt_train_cfg)
108
+ txt_processors["eval"] = _build_proc_from_cfg(txt_eval_cfg)
109
+
110
+ return vis_processors, txt_processors
111
+
112
+
113
+ def load_model_and_preprocess(name, model_type, is_eval=False, device="cpu"):
114
+ """
115
+ Load model and its related preprocessors.
116
+
117
+ List all available models and types in registry:
118
+ >>> from minigpt4.models import model_zoo
119
+ >>> print(model_zoo)
120
+
121
+ Args:
122
+ name (str): name of the model.
123
+ model_type (str): type of the model.
124
+ is_eval (bool): whether the model is in eval mode. Default: False.
125
+ device (str): device to use. Default: "cpu".
126
+
127
+ Returns:
128
+ model (torch.nn.Module): model.
129
+ vis_processors (dict): preprocessors for visual inputs.
130
+ txt_processors (dict): preprocessors for text inputs.
131
+ """
132
+ model_cls = registry.get_model_class(name)
133
+
134
+ # load model
135
+ model = model_cls.from_pretrained(model_type=model_type)
136
+
137
+ if is_eval:
138
+ model.eval()
139
+
140
+ # load preprocess
141
+ cfg = OmegaConf.load(model_cls.default_config_path(model_type))
142
+ if cfg is not None:
143
+ preprocess_cfg = cfg.preprocess
144
+
145
+ vis_processors, txt_processors = load_preprocess(preprocess_cfg)
146
+ else:
147
+ vis_processors, txt_processors = None, None
148
+ logging.info(
149
+ f"""No default preprocess for model {name} ({model_type}).
150
+ This can happen if the model is not finetuned on downstream datasets,
151
+ or it is not intended for direct use without finetuning.
152
+ """
153
+ )
154
+
155
+ if device == "cpu" or device == torch.device("cpu"):
156
+ model = model.float()
157
+
158
+ return model.to(device), vis_processors, txt_processors
159
+
160
+
161
+ class ModelZoo:
162
+ """
163
+ A utility class to create string representation of available model architectures and types.
164
+
165
+ >>> from minigpt4.models import model_zoo
166
+ >>> # list all available models
167
+ >>> print(model_zoo)
168
+ >>> # show total number of models
169
+ >>> print(len(model_zoo))
170
+ """
171
+
172
+ def __init__(self) -> None:
173
+ self.model_zoo = {
174
+ k: list(v.PRETRAINED_MODEL_CONFIG_DICT.keys())
175
+ for k, v in registry.mapping["model_name_mapping"].items()
176
+ }
177
+
178
+ def __str__(self) -> str:
179
+ return (
180
+ "=" * 50
181
+ + "\n"
182
+ + f"{'Architectures':<30} {'Types'}\n"
183
+ + "=" * 50
184
+ + "\n"
185
+ + "\n".join(
186
+ [
187
+ f"{name:<30} {', '.join(types)}"
188
+ for name, types in self.model_zoo.items()
189
+ ]
190
+ )
191
+ )
192
+
193
+ def __iter__(self):
194
+ return iter(self.model_zoo.items())
195
+
196
+ def __len__(self):
197
+ return sum([len(v) for v in self.model_zoo.values()])
198
+
199
+
200
+ model_zoo = ModelZoo()
__pycache__/Qformer.cpython-310.pyc ADDED
Binary file (30.6 kB). View file
 
__pycache__/__init__.cpython-310.pyc ADDED
Binary file (6.11 kB). View file
 
__pycache__/base_model.cpython-310.pyc ADDED
Binary file (8.22 kB). View file
 
__pycache__/base_processor.cpython-310.pyc ADDED
Binary file (1.36 kB). View file
 
__pycache__/blip2.cpython-310.pyc ADDED
Binary file (6.44 kB). View file
 
__pycache__/conversation.cpython-310.pyc ADDED
Binary file (7.25 kB). View file
 
__pycache__/dist_utils.cpython-310.pyc ADDED
Binary file (3.88 kB). View file
 
__pycache__/eva_vit.cpython-310.pyc ADDED
Binary file (14.1 kB). View file
 
__pycache__/logger.cpython-310.pyc ADDED
Binary file (6.43 kB). View file
 
__pycache__/mini_gpt4_llama_v2.cpython-310.pyc ADDED
Binary file (20.8 kB). View file
 
__pycache__/modeling_llama_v2.cpython-310.pyc ADDED
Binary file (4.29 kB). View file
 
__pycache__/registry.cpython-310.pyc ADDED
Binary file (8.31 kB). View file
 
__pycache__/utils.cpython-310.pyc ADDED
Binary file (12.8 kB). View file
 
base_model.py ADDED
@@ -0,0 +1,249 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Copyright (c) 2022, salesforce.com, inc.
3
+ All rights reserved.
4
+ SPDX-License-Identifier: BSD-3-Clause
5
+ For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+
8
+ import logging
9
+ import os
10
+
11
+ import numpy as np
12
+ import torch
13
+ import torch.nn as nn
14
+ from minigpt4_video.dist_utils import download_cached_file, is_dist_avail_and_initialized
15
+ from minigpt4_video.utils import get_abs_path, is_url
16
+ from omegaconf import OmegaConf
17
+
18
+ from huggingface_hub import PyTorchModelHubMixin
19
+
20
+ class BaseModel(nn.Module,PyTorchModelHubMixin):
21
+ """Base class for models."""
22
+
23
+ def __init__(self):
24
+ PyTorchModelHubMixin.__init__(self)
25
+ nn.Module.__init__(self)
26
+
27
+ @property
28
+ def device(self):
29
+ return list(self.parameters())[0].device
30
+
31
+ def load_checkpoint(self, url_or_filename):
32
+ """
33
+ Load from a finetuned checkpoint.
34
+
35
+ This should expect no mismatch in the model keys and the checkpoint keys.
36
+ """
37
+
38
+ if is_url(url_or_filename):
39
+ cached_file = download_cached_file(
40
+ url_or_filename, check_hash=False, progress=True
41
+ )
42
+ checkpoint = torch.load(cached_file, map_location="cpu")
43
+ elif os.path.isfile(url_or_filename):
44
+ checkpoint = torch.load(url_or_filename, map_location="cpu")
45
+ else:
46
+ raise RuntimeError("checkpoint url or path is invalid")
47
+
48
+ if "model" in checkpoint.keys():
49
+ state_dict = checkpoint["model"]
50
+ else:
51
+ state_dict = checkpoint
52
+
53
+ msg = self.load_state_dict(state_dict, strict=False)
54
+
55
+ logging.info("Missing keys {}".format(msg.missing_keys))
56
+ logging.info("load checkpoint from %s" % url_or_filename)
57
+
58
+ return msg
59
+
60
+ @classmethod
61
+ # def from_pretrained(cls, model_type):
62
+ # """
63
+ # Build a pretrained model from default configuration file, specified by model_type.
64
+
65
+ # Args:
66
+ # - model_type (str): model type, specifying architecture and checkpoints.
67
+
68
+ # Returns:
69
+ # - model (nn.Module): pretrained or finetuned model, depending on the configuration.
70
+ # """
71
+ # model_cfg = OmegaConf.load(cls.default_config_path(model_type)).model
72
+ # model = cls.from_config(model_cfg)
73
+
74
+ # return model
75
+
76
+ @classmethod
77
+ def default_config_path(cls, model_type):
78
+ assert (
79
+ model_type in cls.PRETRAINED_MODEL_CONFIG_DICT
80
+ ), "Unknown model type {}".format(model_type)
81
+ return get_abs_path(cls.PRETRAINED_MODEL_CONFIG_DICT[model_type])
82
+
83
+ def load_checkpoint_from_config(self, cfg, **kwargs):
84
+ """
85
+ Load checkpoint as specified in the config file.
86
+
87
+ If load_finetuned is True, load the finetuned model; otherwise, load the pretrained model.
88
+ When loading the pretrained model, each task-specific architecture may define their
89
+ own load_from_pretrained() method.
90
+ """
91
+ load_finetuned = cfg.get("load_finetuned", True)
92
+ if load_finetuned:
93
+ finetune_path = cfg.get("finetuned", None)
94
+ assert (
95
+ finetune_path is not None
96
+ ), "Found load_finetuned is True, but finetune_path is None."
97
+ self.load_checkpoint(url_or_filename=finetune_path)
98
+ else:
99
+ # load pre-trained weights
100
+ pretrain_path = cfg.get("pretrained", None)
101
+ assert "Found load_finetuned is False, but pretrain_path is None."
102
+ self.load_from_pretrained(url_or_filename=pretrain_path, **kwargs)
103
+
104
+ def before_evaluation(self, **kwargs):
105
+ pass
106
+
107
+ def show_n_params(self, return_str=True):
108
+ tot = 0
109
+ for p in self.parameters():
110
+ w = 1
111
+ for x in p.shape:
112
+ w *= x
113
+ tot += w
114
+ if return_str:
115
+ if tot >= 1e6:
116
+ return "{:.1f}M".format(tot / 1e6)
117
+ else:
118
+ return "{:.1f}K".format(tot / 1e3)
119
+ else:
120
+ return tot
121
+
122
+
123
+ class BaseEncoder(nn.Module):
124
+ """
125
+ Base class for primitive encoders, such as ViT, TimeSformer, etc.
126
+ """
127
+
128
+ def __init__(self):
129
+ super().__init__()
130
+
131
+ def forward_features(self, samples, **kwargs):
132
+ raise NotImplementedError
133
+
134
+ @property
135
+ def device(self):
136
+ return list(self.parameters())[0].device
137
+
138
+
139
+ class SharedQueueMixin:
140
+ @torch.no_grad()
141
+ def _dequeue_and_enqueue(self, image_feat, text_feat, idxs=None):
142
+ # gather keys before updating queue
143
+ image_feats = concat_all_gather(image_feat)
144
+ text_feats = concat_all_gather(text_feat)
145
+
146
+ batch_size = image_feats.shape[0]
147
+
148
+ ptr = int(self.queue_ptr)
149
+ assert self.queue_size % batch_size == 0 # for simplicity
150
+
151
+ # replace the keys at ptr (dequeue and enqueue)
152
+ self.image_queue[:, ptr : ptr + batch_size] = image_feats.T
153
+ self.text_queue[:, ptr : ptr + batch_size] = text_feats.T
154
+
155
+ if idxs is not None:
156
+ idxs = concat_all_gather(idxs)
157
+ self.idx_queue[:, ptr : ptr + batch_size] = idxs.T
158
+
159
+ ptr = (ptr + batch_size) % self.queue_size # move pointer
160
+ self.queue_ptr[0] = ptr
161
+
162
+
163
+ class MomentumDistilationMixin:
164
+ @torch.no_grad()
165
+ def copy_params(self):
166
+ for model_pair in self.model_pairs:
167
+ for param, param_m in zip(
168
+ model_pair[0].parameters(), model_pair[1].parameters()
169
+ ):
170
+ param_m.data.copy_(param.data) # initialize
171
+ param_m.requires_grad = False # not update by gradient
172
+
173
+ @torch.no_grad()
174
+ def _momentum_update(self):
175
+ for model_pair in self.model_pairs:
176
+ for param, param_m in zip(
177
+ model_pair[0].parameters(), model_pair[1].parameters()
178
+ ):
179
+ param_m.data = param_m.data * self.momentum + param.data * (
180
+ 1.0 - self.momentum
181
+ )
182
+
183
+
184
+ class GatherLayer(torch.autograd.Function):
185
+ """
186
+ Gather tensors from all workers with support for backward propagation:
187
+ This implementation does not cut the gradients as torch.distributed.all_gather does.
188
+ """
189
+
190
+ @staticmethod
191
+ def forward(ctx, x):
192
+ output = [
193
+ torch.zeros_like(x) for _ in range(torch.distributed.get_world_size())
194
+ ]
195
+ torch.distributed.all_gather(output, x)
196
+ return tuple(output)
197
+
198
+ @staticmethod
199
+ def backward(ctx, *grads):
200
+ all_gradients = torch.stack(grads)
201
+ torch.distributed.all_reduce(all_gradients)
202
+ return all_gradients[torch.distributed.get_rank()]
203
+
204
+
205
+ def all_gather_with_grad(tensors):
206
+ """
207
+ Performs all_gather operation on the provided tensors.
208
+ Graph remains connected for backward grad computation.
209
+ """
210
+ # Queue the gathered tensors
211
+ world_size = torch.distributed.get_world_size()
212
+ # There is no need for reduction in the single-proc case
213
+ if world_size == 1:
214
+ return tensors
215
+
216
+ # tensor_all = GatherLayer.apply(tensors)
217
+ tensor_all = GatherLayer.apply(tensors)
218
+
219
+ return torch.cat(tensor_all, dim=0)
220
+
221
+
222
+ @torch.no_grad()
223
+ def concat_all_gather(tensor):
224
+ """
225
+ Performs all_gather operation on the provided tensors.
226
+ *** Warning ***: torch.distributed.all_gather has no gradient.
227
+ """
228
+ # if use distributed training
229
+ if not is_dist_avail_and_initialized():
230
+ return tensor
231
+
232
+ tensors_gather = [
233
+ torch.ones_like(tensor) for _ in range(torch.distributed.get_world_size())
234
+ ]
235
+ torch.distributed.all_gather(tensors_gather, tensor, async_op=False)
236
+
237
+ output = torch.cat(tensors_gather, dim=0)
238
+ return output
239
+
240
+
241
+ def tile(x, dim, n_tile):
242
+ init_dim = x.size(dim)
243
+ repeat_idx = [1] * x.dim()
244
+ repeat_idx[dim] = n_tile
245
+ x = x.repeat(*(repeat_idx))
246
+ order_index = torch.LongTensor(
247
+ np.concatenate([init_dim * np.arange(n_tile) + i for i in range(init_dim)])
248
+ )
249
+ return torch.index_select(x, dim, order_index.to(x.device))
base_processor.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Copyright (c) 2022, salesforce.com, inc.
3
+ All rights reserved.
4
+ SPDX-License-Identifier: BSD-3-Clause
5
+ For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+
8
+ from omegaconf import OmegaConf
9
+
10
+
11
+ class BaseProcessor:
12
+ def __init__(self):
13
+ self.transform = lambda x: x
14
+ return
15
+
16
+ def __call__(self, item):
17
+ return self.transform(item)
18
+
19
+ @classmethod
20
+ def from_config(cls, cfg=None):
21
+ return cls()
22
+
23
+ def build(self, **kwargs):
24
+ cfg = OmegaConf.create(kwargs)
25
+
26
+ return self.from_config(cfg)
blip2.py ADDED
@@ -0,0 +1,221 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Copyright (c) 2023, salesforce.com, inc.
3
+ All rights reserved.
4
+ SPDX-License-Identifier: BSD-3-Clause
5
+ For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+ import contextlib
8
+ import logging
9
+ import os
10
+ import time
11
+ import datetime
12
+
13
+ import torch
14
+ import torch.nn as nn
15
+ import torch.distributed as dist
16
+ import torch.nn.functional as F
17
+
18
+ from minigpt4_video import dist_utils as dist_utils
19
+ from minigpt4_video.dist_utils import download_cached_file
20
+ from minigpt4_video.utils import is_url
21
+ from minigpt4_video.logger import MetricLogger
22
+ from minigpt4_video.base_model import BaseModel
23
+ from minigpt4_video.Qformer import BertConfig, BertLMHeadModel
24
+ from minigpt4_video.eva_vit import create_eva_vit_g
25
+ from transformers import BertTokenizer
26
+
27
+
28
+ class Blip2Base(BaseModel):
29
+ @classmethod
30
+ def init_tokenizer(cls):
31
+ tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
32
+ tokenizer.add_special_tokens({"bos_token": "[DEC]"})
33
+ return tokenizer
34
+
35
+ def maybe_autocast(self, dtype=torch.float16):
36
+ # if on cpu, don't use autocast
37
+ # if on gpu, use autocast with dtype if provided, otherwise use torch.float16
38
+ enable_autocast = self.device != torch.device("cpu")
39
+
40
+ if enable_autocast:
41
+ return torch.cuda.amp.autocast(dtype=dtype)
42
+ else:
43
+ return contextlib.nullcontext()
44
+
45
+ @classmethod
46
+ def init_Qformer(cls, num_query_token, vision_width, cross_attention_freq=2):
47
+ encoder_config = BertConfig.from_pretrained("bert-base-uncased")
48
+ encoder_config.encoder_width = vision_width
49
+ # insert cross-attention layer every other block
50
+ encoder_config.add_cross_attention = True
51
+ encoder_config.cross_attention_freq = cross_attention_freq
52
+ encoder_config.query_length = num_query_token
53
+ Qformer = BertLMHeadModel(config=encoder_config)
54
+ query_tokens = nn.Parameter(
55
+ torch.zeros(1, num_query_token, encoder_config.hidden_size)
56
+ )
57
+ query_tokens.data.normal_(mean=0.0, std=encoder_config.initializer_range)
58
+ return Qformer, query_tokens
59
+
60
+ @classmethod
61
+ def init_vision_encoder(
62
+ cls, model_name, img_size, drop_path_rate, use_grad_checkpoint, precision
63
+ ):
64
+ assert model_name == "eva_clip_g", "vit model must be eva_clip_g for current version of MiniGPT-4"
65
+ visual_encoder = create_eva_vit_g(
66
+ img_size, drop_path_rate, use_grad_checkpoint, precision
67
+ )
68
+
69
+ ln_vision = LayerNorm(visual_encoder.num_features)
70
+ return visual_encoder, ln_vision
71
+
72
+ def load_from_pretrained(self, url_or_filename):
73
+ if is_url(url_or_filename):
74
+ cached_file = download_cached_file(
75
+ url_or_filename, check_hash=False, progress=True
76
+ )
77
+ checkpoint = torch.load(cached_file, map_location="cpu")
78
+ elif os.path.isfile(url_or_filename):
79
+ checkpoint = torch.load(url_or_filename, map_location="cpu")
80
+ else:
81
+ raise RuntimeError("checkpoint url or path is invalid")
82
+
83
+ state_dict = checkpoint["model"]
84
+
85
+ msg = self.load_state_dict(state_dict, strict=False)
86
+
87
+ # logging.info("Missing keys {}".format(msg.missing_keys))
88
+ logging.info("load checkpoint from %s" % url_or_filename)
89
+
90
+ return msg
91
+
92
+
93
+ def disabled_train(self, mode=True):
94
+ """Overwrite model.train with this function to make sure train/eval mode
95
+ does not change anymore."""
96
+ return self
97
+
98
+
99
+ class LayerNorm(nn.LayerNorm):
100
+ """Subclass torch's LayerNorm to handle fp16."""
101
+
102
+ def forward(self, x: torch.Tensor):
103
+ orig_type = x.dtype
104
+ ret = super().forward(x.type(torch.float32))
105
+ return ret.type(orig_type)
106
+
107
+
108
+ def compute_sim_matrix(model, data_loader, **kwargs):
109
+ k_test = kwargs.pop("k_test")
110
+
111
+ metric_logger = MetricLogger(delimiter=" ")
112
+ header = "Evaluation:"
113
+
114
+ logging.info("Computing features for evaluation...")
115
+ start_time = time.time()
116
+
117
+ texts = data_loader.dataset.text
118
+ num_text = len(texts)
119
+ text_bs = 256
120
+ text_ids = []
121
+ text_embeds = []
122
+ text_atts = []
123
+ for i in range(0, num_text, text_bs):
124
+ text = texts[i : min(num_text, i + text_bs)]
125
+ text_input = model.tokenizer(
126
+ text,
127
+ padding="max_length",
128
+ truncation=True,
129
+ max_length=35,
130
+ return_tensors="pt",
131
+ ).to(model.device)
132
+ text_feat = model.forward_text(text_input)
133
+ text_embed = F.normalize(model.text_proj(text_feat))
134
+ text_embeds.append(text_embed)
135
+ text_ids.append(text_input.input_ids)
136
+ text_atts.append(text_input.attention_mask)
137
+
138
+ text_embeds = torch.cat(text_embeds, dim=0)
139
+ text_ids = torch.cat(text_ids, dim=0)
140
+ text_atts = torch.cat(text_atts, dim=0)
141
+
142
+ vit_feats = []
143
+ image_embeds = []
144
+ for samples in data_loader:
145
+ image = samples["image"]
146
+
147
+ image = image.to(model.device)
148
+ image_feat, vit_feat = model.forward_image(image)
149
+ image_embed = model.vision_proj(image_feat)
150
+ image_embed = F.normalize(image_embed, dim=-1)
151
+
152
+ vit_feats.append(vit_feat.cpu())
153
+ image_embeds.append(image_embed)
154
+
155
+ vit_feats = torch.cat(vit_feats, dim=0)
156
+ image_embeds = torch.cat(image_embeds, dim=0)
157
+
158
+ sims_matrix = []
159
+ for image_embed in image_embeds:
160
+ sim_q2t = image_embed @ text_embeds.t()
161
+ sim_i2t, _ = sim_q2t.max(0)
162
+ sims_matrix.append(sim_i2t)
163
+ sims_matrix = torch.stack(sims_matrix, dim=0)
164
+
165
+ score_matrix_i2t = torch.full(
166
+ (len(data_loader.dataset.image), len(texts)), -100.0
167
+ ).to(model.device)
168
+
169
+ num_tasks = dist_utils.get_world_size()
170
+ rank = dist_utils.get_rank()
171
+ step = sims_matrix.size(0) // num_tasks + 1
172
+ start = rank * step
173
+ end = min(sims_matrix.size(0), start + step)
174
+
175
+ for i, sims in enumerate(
176
+ metric_logger.log_every(sims_matrix[start:end], 50, header)
177
+ ):
178
+ topk_sim, topk_idx = sims.topk(k=k_test, dim=0)
179
+ image_inputs = vit_feats[start + i].repeat(k_test, 1, 1).to(model.device)
180
+ score = model.compute_itm(
181
+ image_inputs=image_inputs,
182
+ text_ids=text_ids[topk_idx],
183
+ text_atts=text_atts[topk_idx],
184
+ ).float()
185
+ score_matrix_i2t[start + i, topk_idx] = score + topk_sim
186
+
187
+ sims_matrix = sims_matrix.t()
188
+ score_matrix_t2i = torch.full(
189
+ (len(texts), len(data_loader.dataset.image)), -100.0
190
+ ).to(model.device)
191
+
192
+ step = sims_matrix.size(0) // num_tasks + 1
193
+ start = rank * step
194
+ end = min(sims_matrix.size(0), start + step)
195
+
196
+ for i, sims in enumerate(
197
+ metric_logger.log_every(sims_matrix[start:end], 50, header)
198
+ ):
199
+ topk_sim, topk_idx = sims.topk(k=k_test, dim=0)
200
+ image_inputs = vit_feats[topk_idx.cpu()].to(model.device)
201
+ score = model.compute_itm(
202
+ image_inputs=image_inputs,
203
+ text_ids=text_ids[start + i].repeat(k_test, 1),
204
+ text_atts=text_atts[start + i].repeat(k_test, 1),
205
+ ).float()
206
+ score_matrix_t2i[start + i, topk_idx] = score + topk_sim
207
+
208
+ if dist_utils.is_dist_avail_and_initialized():
209
+ dist.barrier()
210
+ torch.distributed.all_reduce(
211
+ score_matrix_i2t, op=torch.distributed.ReduceOp.SUM
212
+ )
213
+ torch.distributed.all_reduce(
214
+ score_matrix_t2i, op=torch.distributed.ReduceOp.SUM
215
+ )
216
+
217
+ total_time = time.time() - start_time
218
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
219
+ logging.info("Evaluation time {}".format(total_time_str))
220
+
221
+ return score_matrix_i2t.cpu().numpy(), score_matrix_t2i.cpu().numpy()
blip2_outputs.py ADDED
@@ -0,0 +1,110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Copyright (c) 2022, salesforce.com, inc.
3
+ All rights reserved.
4
+ SPDX-License-Identifier: BSD-3-Clause
5
+ For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+
8
+ from dataclasses import dataclass
9
+ from typing import Optional
10
+
11
+ import torch
12
+ from transformers.modeling_outputs import (
13
+ ModelOutput,
14
+ BaseModelOutputWithPoolingAndCrossAttentions,
15
+ CausalLMOutputWithCrossAttentions,
16
+ )
17
+
18
+
19
+ @dataclass
20
+ class BlipSimilarity(ModelOutput):
21
+ sim_i2t: torch.FloatTensor = None
22
+ sim_t2i: torch.FloatTensor = None
23
+
24
+ sim_i2t_m: Optional[torch.FloatTensor] = None
25
+ sim_t2i_m: Optional[torch.FloatTensor] = None
26
+
27
+ sim_i2t_targets: Optional[torch.FloatTensor] = None
28
+ sim_t2i_targets: Optional[torch.FloatTensor] = None
29
+
30
+
31
+ @dataclass
32
+ class BlipIntermediateOutput(ModelOutput):
33
+ """
34
+ Data class for intermediate outputs of BLIP models.
35
+
36
+ image_embeds (torch.FloatTensor): Image embeddings, shape (batch_size, num_patches, embed_dim).
37
+ text_embeds (torch.FloatTensor): Text embeddings, shape (batch_size, seq_len, embed_dim).
38
+
39
+ image_embeds_m (torch.FloatTensor): Image embeddings from momentum visual encoder, shape (batch_size, num_patches, embed_dim).
40
+ text_embeds_m (torch.FloatTensor): Text embeddings from momentum text encoder, shape (batch_size, seq_len, embed_dim).
41
+
42
+ encoder_output (BaseModelOutputWithPoolingAndCrossAttentions): output from the image-grounded text encoder.
43
+ encoder_output_neg (BaseModelOutputWithPoolingAndCrossAttentions): output from the image-grounded text encoder for negative pairs.
44
+
45
+ decoder_output (CausalLMOutputWithCrossAttentions): output from the image-grounded text decoder.
46
+ decoder_labels (torch.LongTensor): labels for the captioning loss.
47
+
48
+ itm_logits (torch.FloatTensor): logits for the image-text matching loss, shape (batch_size * 3, 2).
49
+ itm_labels (torch.LongTensor): labels for the image-text matching loss, shape (batch_size * 3,)
50
+
51
+ """
52
+
53
+ # uni-modal features
54
+ image_embeds: torch.FloatTensor = None
55
+ text_embeds: Optional[torch.FloatTensor] = None
56
+
57
+ image_embeds_m: Optional[torch.FloatTensor] = None
58
+ text_embeds_m: Optional[torch.FloatTensor] = None
59
+
60
+ # intermediate outputs of multimodal encoder
61
+ encoder_output: Optional[BaseModelOutputWithPoolingAndCrossAttentions] = None
62
+ encoder_output_neg: Optional[BaseModelOutputWithPoolingAndCrossAttentions] = None
63
+
64
+ itm_logits: Optional[torch.FloatTensor] = None
65
+ itm_labels: Optional[torch.LongTensor] = None
66
+
67
+ # intermediate outputs of multimodal decoder
68
+ decoder_output: Optional[CausalLMOutputWithCrossAttentions] = None
69
+ decoder_labels: Optional[torch.LongTensor] = None
70
+
71
+
72
+ @dataclass
73
+ class BlipOutput(ModelOutput):
74
+ # some finetuned models (e.g. BlipVQA) do not compute similarity, thus optional.
75
+ sims: Optional[BlipSimilarity] = None
76
+
77
+ intermediate_output: BlipIntermediateOutput = None
78
+
79
+ loss: Optional[torch.FloatTensor] = None
80
+
81
+ loss_itc: Optional[torch.FloatTensor] = None
82
+
83
+ loss_itm: Optional[torch.FloatTensor] = None
84
+
85
+ loss_lm: Optional[torch.FloatTensor] = None
86
+
87
+
88
+ @dataclass
89
+ class BlipOutputFeatures(ModelOutput):
90
+ """
91
+ Data class of features from BlipFeatureExtractor.
92
+
93
+ Args:
94
+ image_embeds: (torch.FloatTensor) of shape (batch_size, num_patches+1, embed_dim), optional
95
+ image_features: (torch.FloatTensor) of shape (batch_size, num_patches+1, feature_dim), optional
96
+ text_embeds: (torch.FloatTensor) of shape (batch_size, sequence_length+1, embed_dim), optional
97
+ text_features: (torch.FloatTensor) of shape (batch_size, sequence_length+1, feature_dim), optional
98
+
99
+ The first embedding or feature is for the [CLS] token.
100
+
101
+ Features are obtained by projecting the corresponding embedding into a normalized low-dimensional space.
102
+ """
103
+
104
+ image_embeds: Optional[torch.FloatTensor] = None
105
+ image_embeds_proj: Optional[torch.FloatTensor] = None
106
+
107
+ text_embeds: Optional[torch.FloatTensor] = None
108
+ text_embeds_proj: Optional[torch.FloatTensor] = None
109
+
110
+ multimodal_embeds: Optional[torch.FloatTensor] = None
clip_vision_encoder.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from transformers import CLIPVisionModel, CLIPImageProcessor, CLIPVisionConfig
5
+
6
+
7
+ class CLIPVisionEncoder(nn.Module):
8
+ def __init__(self, encoder_name="openai/clip-vit-large-patch14", delay_load=False):
9
+ super().__init__()
10
+
11
+ self.is_loaded = False
12
+
13
+ self.vision_encoder_name = encoder_name
14
+ # self.select_layer = args.mm_vision_select_layer
15
+ # self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch')
16
+ self.select_layer = -1
17
+ self.select_feature = "patch"
18
+ if not delay_load:
19
+ self.load_model()
20
+ else:
21
+ self.cfg_only = CLIPVisionConfig.from_pretrained(self.vision_encoder_name)
22
+
23
+ def load_model(self):
24
+ self.image_processor = CLIPImageProcessor.from_pretrained(self.vision_encoder_name)
25
+ self.vision_encoder = CLIPVisionModel.from_pretrained(self.vision_encoder_name)
26
+ self.vision_encoder.requires_grad_(False)
27
+
28
+ self.is_loaded = True
29
+
30
+ def feature_select(self, image_forward_outs):
31
+ image_features = image_forward_outs.hidden_states[self.select_layer]
32
+ if self.select_feature == 'patch':
33
+ image_features = image_features[:, :]
34
+ elif self.select_feature == 'cls_patch':
35
+ image_features = image_features
36
+ else:
37
+ raise ValueError(f'Unexpected select feature: {self.select_feature}')
38
+ return image_features
39
+
40
+ @torch.no_grad()
41
+ def forward(self, images):
42
+ if type(images) is list:
43
+ image_features = []
44
+ for image in images:
45
+ image_forward_out = self.vision_encoder(image.to(device=self.device, dtype=self.dtype).unsqueeze(0), output_hidden_states=True)
46
+ image_feature = self.feature_select(image_forward_out).to(image.dtype)
47
+ image_features.append(image_feature)
48
+ else:
49
+ image_forward_outs = self.vision_encoder(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True)
50
+ image_features = self.feature_select(image_forward_outs).to(images.dtype)
51
+ # print("image feature shape", image_features.shape)
52
+ # print(type(image_forward_outs))
53
+ # print(type(image_forward_outs.shape))
54
+ # image_features = image_forward_outs.to(images.dtype)
55
+
56
+ return image_features
57
+
58
+ @property
59
+ def dummy_feature(self):
60
+ return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
61
+
62
+ @property
63
+ def dtype(self):
64
+ return self.vision_encoder.dtype
65
+
66
+ @property
67
+ def device(self):
68
+ return self.vision_encoder.device
69
+
70
+ @property
71
+ def config(self):
72
+ if self.is_loaded:
73
+ return self.vision_encoder.config
74
+ else:
75
+ return self.cfg_only
76
+
77
+ @property
78
+ def hidden_size(self):
79
+ return self.config.hidden_size
80
+
81
+ @property
82
+ def num_patches(self):
83
+ return (self.config.image_size // self.config.patch_size) ** 2
config.py ADDED
@@ -0,0 +1,474 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Copyright (c) 2022, salesforce.com, inc.
3
+ All rights reserved.
4
+ SPDX-License-Identifier: BSD-3-Clause
5
+ For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+
8
+ import logging
9
+ import json
10
+ from typing import Dict
11
+
12
+ from omegaconf import OmegaConf
13
+ from minigpt4_video.registry import registry
14
+
15
+
16
+ class Config:
17
+ def __init__(self, args):
18
+ self.config = {}
19
+
20
+ self.args = args
21
+
22
+ # Register the config and configuration for setup
23
+ registry.register("configuration", self)
24
+
25
+ user_config = self._build_opt_list(self.args.options)
26
+
27
+ config = OmegaConf.load(self.args.cfg_path)
28
+
29
+ runner_config = self.build_runner_config(config)
30
+ model_config = self.build_model_config(config, **user_config)
31
+ dataset_config = self.build_dataset_config(config)
32
+
33
+ # Validate the user-provided runner configuration
34
+ # model and dataset configuration are supposed to be validated by the respective classes
35
+ # [TODO] validate the model/dataset configuration
36
+ # self._validate_runner_config(runner_config)
37
+
38
+ # Override the default configuration with user options.
39
+ self.config = OmegaConf.merge(
40
+ runner_config, model_config, dataset_config, user_config
41
+ )
42
+
43
+ def _validate_runner_config(self, runner_config):
44
+ """
45
+ This method validates the configuration, such that
46
+ 1) all the user specified options are valid;
47
+ 2) no type mismatches between the user specified options and the config.
48
+ """
49
+ runner_config_validator = create_runner_config_validator()
50
+ runner_config_validator.validate(runner_config)
51
+
52
+ def _build_opt_list(self, opts):
53
+ opts_dot_list = self._convert_to_dot_list(opts)
54
+ return OmegaConf.from_dotlist(opts_dot_list)
55
+
56
+ @staticmethod
57
+ def build_model_config(config, **kwargs):
58
+ model = config.get("model", None)
59
+ assert model is not None, "Missing model configuration file."
60
+
61
+ model_cls = registry.get_model_class(model.arch)
62
+ assert model_cls is not None, f"Model '{model.arch}' has not been registered."
63
+
64
+ model_type = kwargs.get("model.model_type", None)
65
+ if not model_type:
66
+ model_type = model.get("model_type", None)
67
+ # else use the model type selected by user.
68
+
69
+ assert model_type is not None, "Missing model_type."
70
+
71
+ print("--------------")
72
+ print("model arch",model.arch)
73
+ print("model cls",model_cls)
74
+
75
+ model_config_path = model_cls.PRETRAINED_MODEL_CONFIG_DICT[model_type]
76
+
77
+ model_config = OmegaConf.create()
78
+ # hierarchy override, customized config > default config
79
+ model_config = OmegaConf.merge(
80
+ model_config,
81
+ OmegaConf.load(model_config_path),
82
+ {"model": config["model"]},
83
+ )
84
+
85
+ return model_config
86
+
87
+ @staticmethod
88
+ def build_runner_config(config):
89
+ return {"run": config.run}
90
+
91
+ @staticmethod
92
+ def build_dataset_config(config):
93
+ datasets = config.get("datasets", None)
94
+ if datasets is None:
95
+ raise KeyError(
96
+ "Expecting 'datasets' as the root key for dataset configuration."
97
+ )
98
+
99
+ dataset_config = OmegaConf.create()
100
+
101
+ for dataset_name in datasets:
102
+
103
+ print("dataset name", dataset_name)
104
+ builder_cls = registry.get_builder_class(dataset_name)
105
+
106
+ dataset_config_type = datasets[dataset_name].get("type", "default")
107
+ dataset_config_path = builder_cls.default_config_path(
108
+ type=dataset_config_type
109
+ )
110
+
111
+ # hierarchy override, customized config > default config
112
+ dataset_config = OmegaConf.merge(
113
+ dataset_config,
114
+ OmegaConf.load(dataset_config_path),
115
+ {"datasets": {dataset_name: config["datasets"][dataset_name]}},
116
+ )
117
+
118
+ return dataset_config
119
+
120
+ def _convert_to_dot_list(self, opts):
121
+ if opts is None:
122
+ opts = []
123
+
124
+ if len(opts) == 0:
125
+ return opts
126
+
127
+ has_equal = opts[0].find("=") != -1
128
+
129
+ if has_equal:
130
+ return opts
131
+
132
+ return [(opt + "=" + value) for opt, value in zip(opts[0::2], opts[1::2])]
133
+
134
+ def get_config(self):
135
+ return self.config
136
+
137
+ @property
138
+ def run_cfg(self):
139
+ return self.config.run
140
+
141
+ @property
142
+ def datasets_cfg(self):
143
+ return self.config.datasets
144
+
145
+ @property
146
+ def model_cfg(self):
147
+ return self.config.model
148
+
149
+ def pretty_print(self):
150
+ logging.info("\n===== Running Parameters =====")
151
+ logging.info(self._convert_node_to_json(self.config.run))
152
+
153
+ logging.info("\n====== Dataset Attributes ======")
154
+ datasets = self.config.datasets
155
+
156
+ for dataset in datasets:
157
+ if dataset in self.config.datasets:
158
+ logging.info(f"\n======== {dataset} =======")
159
+ dataset_config = self.config.datasets[dataset]
160
+ logging.info(self._convert_node_to_json(dataset_config))
161
+ else:
162
+ logging.warning(f"No dataset named '{dataset}' in config. Skipping")
163
+
164
+ logging.info(f"\n====== Model Attributes ======")
165
+ logging.info(self._convert_node_to_json(self.config.model))
166
+
167
+ def _convert_node_to_json(self, node):
168
+ container = OmegaConf.to_container(node, resolve=True)
169
+ return json.dumps(container, indent=4, sort_keys=True)
170
+
171
+ def to_dict(self):
172
+ return OmegaConf.to_container(self.config)
173
+
174
+
175
+ def node_to_dict(node):
176
+ return OmegaConf.to_container(node)
177
+
178
+
179
+ class ConfigValidator:
180
+ """
181
+ This is a preliminary implementation to centralize and validate the configuration.
182
+ May be altered in the future.
183
+
184
+ A helper class to validate configurations from yaml file.
185
+
186
+ This serves the following purposes:
187
+ 1. Ensure all the options in the yaml are defined, raise error if not.
188
+ 2. when type mismatches are found, the validator will raise an error.
189
+ 3. a central place to store and display helpful messages for supported configurations.
190
+
191
+ """
192
+
193
+ class _Argument:
194
+ def __init__(self, name, choices=None, type=None, help=None):
195
+ self.name = name
196
+ self.val = None
197
+ self.choices = choices
198
+ self.type = type
199
+ self.help = help
200
+
201
+ def __str__(self):
202
+ s = f"{self.name}={self.val}"
203
+ if self.type is not None:
204
+ s += f", ({self.type})"
205
+ if self.choices is not None:
206
+ s += f", choices: {self.choices}"
207
+ if self.help is not None:
208
+ s += f", ({self.help})"
209
+ return s
210
+
211
+ def __init__(self, description):
212
+ self.description = description
213
+
214
+ self.arguments = dict()
215
+
216
+ self.parsed_args = None
217
+
218
+ def __getitem__(self, key):
219
+ assert self.parsed_args is not None, "No arguments parsed yet."
220
+
221
+ return self.parsed_args[key]
222
+
223
+ def __str__(self) -> str:
224
+ return self.format_help()
225
+
226
+ def add_argument(self, *args, **kwargs):
227
+ """
228
+ Assume the first argument is the name of the argument.
229
+ """
230
+ self.arguments[args[0]] = self._Argument(*args, **kwargs)
231
+
232
+ def validate(self, config=None):
233
+ """
234
+ Convert yaml config (dict-like) to list, required by argparse.
235
+ """
236
+ for k, v in config.items():
237
+ assert (
238
+ k in self.arguments
239
+ ), f"""{k} is not a valid argument. Support arguments are {self.format_arguments()}."""
240
+
241
+ if self.arguments[k].type is not None:
242
+ try:
243
+ self.arguments[k].val = self.arguments[k].type(v)
244
+ except ValueError:
245
+ raise ValueError(f"{k} is not a valid {self.arguments[k].type}.")
246
+
247
+ if self.arguments[k].choices is not None:
248
+ assert (
249
+ v in self.arguments[k].choices
250
+ ), f"""{k} must be one of {self.arguments[k].choices}."""
251
+
252
+ return config
253
+
254
+ def format_arguments(self):
255
+ return str([f"{k}" for k in sorted(self.arguments.keys())])
256
+
257
+ def format_help(self):
258
+ # description + key-value pair string for each argument
259
+ help_msg = str(self.description)
260
+ return help_msg + ", available arguments: " + self.format_arguments()
261
+
262
+ def print_help(self):
263
+ # display help message
264
+ print(self.format_help())
265
+
266
+
267
+ def create_runner_config_validator():
268
+ validator = ConfigValidator(description="Runner configurations")
269
+
270
+ validator.add_argument(
271
+ "runner",
272
+ type=str,
273
+ choices=["runner_base", "runner_iter"],
274
+ help="""Runner to use. The "runner_base" uses epoch-based training while iter-based
275
+ runner runs based on iters. Default: runner_base""",
276
+ )
277
+ # add argumetns for training dataset ratios
278
+ validator.add_argument(
279
+ "train_dataset_ratios",
280
+ type=Dict[str, float],
281
+ help="""Ratios of training dataset. This is used in iteration-based runner.
282
+ Do not support for epoch-based runner because how to define an epoch becomes tricky.
283
+ Default: None""",
284
+ )
285
+ validator.add_argument(
286
+ "max_iters",
287
+ type=float,
288
+ help="Maximum number of iterations to run.",
289
+ )
290
+ validator.add_argument(
291
+ "max_epoch",
292
+ type=int,
293
+ help="Maximum number of epochs to run.",
294
+ )
295
+ # add arguments for iters_per_inner_epoch
296
+ validator.add_argument(
297
+ "iters_per_inner_epoch",
298
+ type=float,
299
+ help="Number of iterations per inner epoch. This is required when runner is runner_iter.",
300
+ )
301
+ lr_scheds_choices = registry.list_lr_schedulers()
302
+ validator.add_argument(
303
+ "lr_sched",
304
+ type=str,
305
+ choices=lr_scheds_choices,
306
+ help="Learning rate scheduler to use, from {}".format(lr_scheds_choices),
307
+ )
308
+ task_choices = registry.list_tasks()
309
+ validator.add_argument(
310
+ "task",
311
+ type=str,
312
+ choices=task_choices,
313
+ help="Task to use, from {}".format(task_choices),
314
+ )
315
+ # add arguments for init_lr
316
+ validator.add_argument(
317
+ "init_lr",
318
+ type=float,
319
+ help="Initial learning rate. This will be the learning rate after warmup and before decay.",
320
+ )
321
+ # add arguments for min_lr
322
+ validator.add_argument(
323
+ "min_lr",
324
+ type=float,
325
+ help="Minimum learning rate (after decay).",
326
+ )
327
+ # add arguments for warmup_lr
328
+ validator.add_argument(
329
+ "warmup_lr",
330
+ type=float,
331
+ help="Starting learning rate for warmup.",
332
+ )
333
+ # add arguments for learning rate decay rate
334
+ validator.add_argument(
335
+ "lr_decay_rate",
336
+ type=float,
337
+ help="Learning rate decay rate. Required if using a decaying learning rate scheduler.",
338
+ )
339
+ # add arguments for weight decay
340
+ validator.add_argument(
341
+ "weight_decay",
342
+ type=float,
343
+ help="Weight decay rate.",
344
+ )
345
+ # add arguments for training batch size
346
+ validator.add_argument(
347
+ "batch_size_train",
348
+ type=int,
349
+ help="Training batch size.",
350
+ )
351
+ # add arguments for evaluation batch size
352
+ validator.add_argument(
353
+ "batch_size_eval",
354
+ type=int,
355
+ help="Evaluation batch size, including validation and testing.",
356
+ )
357
+ # add arguments for number of workers for data loading
358
+ validator.add_argument(
359
+ "num_workers",
360
+ help="Number of workers for data loading.",
361
+ )
362
+ # add arguments for warm up steps
363
+ validator.add_argument(
364
+ "warmup_steps",
365
+ type=int,
366
+ help="Number of warmup steps. Required if a warmup schedule is used.",
367
+ )
368
+ # add arguments for random seed
369
+ validator.add_argument(
370
+ "seed",
371
+ type=int,
372
+ help="Random seed.",
373
+ )
374
+ # add arguments for output directory
375
+ validator.add_argument(
376
+ "output_dir",
377
+ type=str,
378
+ help="Output directory to save checkpoints and logs.",
379
+ )
380
+ # add arguments for whether only use evaluation
381
+ validator.add_argument(
382
+ "evaluate",
383
+ help="Whether to only evaluate the model. If true, training will not be performed.",
384
+ )
385
+ # add arguments for splits used for training, e.g. ["train", "val"]
386
+ validator.add_argument(
387
+ "train_splits",
388
+ type=list,
389
+ help="Splits to use for training.",
390
+ )
391
+ # add arguments for splits used for validation, e.g. ["val"]
392
+ validator.add_argument(
393
+ "valid_splits",
394
+ type=list,
395
+ help="Splits to use for validation. If not provided, will skip the validation.",
396
+ )
397
+ # add arguments for splits used for testing, e.g. ["test"]
398
+ validator.add_argument(
399
+ "test_splits",
400
+ type=list,
401
+ help="Splits to use for testing. If not provided, will skip the testing.",
402
+ )
403
+ # add arguments for accumulating gradient for iterations
404
+ validator.add_argument(
405
+ "accum_grad_iters",
406
+ type=int,
407
+ help="Number of iterations to accumulate gradient for.",
408
+ )
409
+
410
+ # ====== distributed training ======
411
+ validator.add_argument(
412
+ "device",
413
+ type=str,
414
+ choices=["cpu", "cuda"],
415
+ help="Device to use. Support 'cuda' or 'cpu' as for now.",
416
+ )
417
+ validator.add_argument(
418
+ "world_size",
419
+ type=int,
420
+ help="Number of processes participating in the job.",
421
+ )
422
+ validator.add_argument("dist_url", type=str)
423
+ validator.add_argument("distributed", type=bool)
424
+ # add arguments to opt using distributed sampler during evaluation or not
425
+ validator.add_argument(
426
+ "use_dist_eval_sampler",
427
+ type=bool,
428
+ help="Whether to use distributed sampler during evaluation or not.",
429
+ )
430
+
431
+ # ====== task specific ======
432
+ # generation task specific arguments
433
+ # add arguments for maximal length of text output
434
+ validator.add_argument(
435
+ "max_len",
436
+ type=int,
437
+ help="Maximal length of text output.",
438
+ )
439
+ # add arguments for minimal length of text output
440
+ validator.add_argument(
441
+ "min_len",
442
+ type=int,
443
+ help="Minimal length of text output.",
444
+ )
445
+ # add arguments number of beams
446
+ validator.add_argument(
447
+ "num_beams",
448
+ type=int,
449
+ help="Number of beams used for beam search.",
450
+ )
451
+
452
+ # vqa task specific arguments
453
+ # add arguments for number of answer candidates
454
+ validator.add_argument(
455
+ "num_ans_candidates",
456
+ type=int,
457
+ help="""For ALBEF and BLIP, these models first rank answers according to likelihood to select answer candidates.""",
458
+ )
459
+ # add arguments for inference method
460
+ validator.add_argument(
461
+ "inference_method",
462
+ type=str,
463
+ choices=["genearte", "rank"],
464
+ help="""Inference method to use for question answering. If rank, requires a answer list.""",
465
+ )
466
+
467
+ # ====== model specific ======
468
+ validator.add_argument(
469
+ "k_test",
470
+ type=int,
471
+ help="Number of top k most similar samples from ITC/VTC selection to be tested.",
472
+ )
473
+
474
+ return validator
conversation.py ADDED
@@ -0,0 +1,224 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import time
3
+ from PIL import Image
4
+
5
+ import torch
6
+ from transformers import AutoTokenizer, AutoModelForCausalLM, LlamaTokenizer
7
+ from transformers import StoppingCriteria, StoppingCriteriaList
8
+
9
+ import dataclasses
10
+ from enum import auto, Enum
11
+ from typing import List, Tuple, Any
12
+
13
+ from minigpt4_video.registry import registry
14
+
15
+
16
+ class SeparatorStyle(Enum):
17
+ """Different separator style."""
18
+ SINGLE = auto()
19
+ TWO = auto()
20
+
21
+
22
+ @dataclasses.dataclass
23
+ class Conversation:
24
+ """A class that keeps all conversation history."""
25
+ system: str
26
+ roles: List[str]
27
+ messages: List[List[str]]
28
+ offset: int
29
+ # system_img: List[Image.Image] = []
30
+ sep_style: SeparatorStyle = SeparatorStyle.SINGLE
31
+ sep: str = "<s>"
32
+ sep2: str = "</s>"
33
+
34
+ skip_next: bool = False
35
+ conv_id: Any = None
36
+
37
+ def get_prompt(self):
38
+ if self.sep_style == SeparatorStyle.SINGLE:
39
+ # ret = self.system + self.sep
40
+ ret = self.system +"<s>"
41
+ for role, message in self.messages:
42
+ if message:
43
+ # ret += role + ": " + message + self.sep
44
+ ret+= role + message
45
+ # ret+= role + message
46
+ else:
47
+ # ret += role + ":"
48
+ # ret += self.sep2 + role
49
+ ret += role
50
+ return ret
51
+ elif self.sep_style == SeparatorStyle.TWO:
52
+ seps = [self.sep, self.sep2]
53
+ # ret = self.system + seps[0]
54
+ ret = self.system+"<s>"
55
+ for i, (role, message) in enumerate(self.messages):
56
+ if message:
57
+ # ret += role + ": " + message + seps[i % 2]
58
+ ret += role+message+seps[i%2]
59
+ else:
60
+ # ret += role + ":"
61
+ ret += role
62
+ return ret
63
+ else:
64
+ raise ValueError(f"Invalid style: {self.sep_style}")
65
+
66
+ def append_message(self, role, message):
67
+ self.messages.append([role, message])
68
+
69
+ def to_gradio_chatbot(self):
70
+ ret = []
71
+ for i, (role, msg) in enumerate(self.messages[self.offset:]):
72
+ if i % 2 == 0:
73
+ ret.append([msg, None])
74
+ else:
75
+ ret[-1][-1] = msg
76
+ return ret
77
+
78
+ def copy(self):
79
+ return Conversation(
80
+ system=self.system,
81
+ # system_img=self.system_img,
82
+ roles=self.roles,
83
+ messages=[[x, y] for x, y in self.messages],
84
+ offset=self.offset,
85
+ sep_style=self.sep_style,
86
+ sep=self.sep,
87
+ sep2=self.sep2,
88
+ conv_id=self.conv_id)
89
+
90
+ def dict(self):
91
+ return {
92
+ "system": self.system,
93
+ # "system_img": self.system_img,
94
+ "roles": self.roles,
95
+ "messages": self.messages,
96
+ "offset": self.offset,
97
+ "sep": self.sep,
98
+ "sep2": self.sep2,
99
+ "conv_id": self.conv_id,
100
+ }
101
+
102
+
103
+ class StoppingCriteriaSub(StoppingCriteria):
104
+
105
+ def __init__(self, stops=[], encounters=1):
106
+ super().__init__()
107
+ self.stops = stops
108
+
109
+ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
110
+ for stop in self.stops:
111
+ if torch.all((stop == input_ids[0][-len(stop):])).item():
112
+ return True
113
+
114
+ return False
115
+
116
+
117
+ CONV_VISION = Conversation(
118
+ # system="Give the following image: <Img>ImageContent</Img>. "
119
+ # "You will be able to see the image once I provide it to you. Please answer my questions.",
120
+ system = "",
121
+ roles = (r"[INST] ",r" [/INST]"),
122
+ messages=[],
123
+ offset=2,
124
+ sep_style=SeparatorStyle.SINGLE,
125
+ sep="<s>",
126
+ )
127
+
128
+
129
+ class Chat:
130
+ def __init__(self, model, vis_processor, device='cuda:0'):
131
+ self.device = device
132
+ self.model = model
133
+ self.vis_processor = vis_processor
134
+
135
+ self.conv = CONV_VISION.copy()
136
+ self.img_list = []
137
+ self.raw_answers = []
138
+
139
+ stop_words_ids = [torch.tensor([2]).to(self.device)]
140
+ self.stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)])
141
+
142
+ def reset(self):
143
+ self.conv.messages = []
144
+ self.img_list = []
145
+ # self.img_list = [img for img in self.conv.system_img]
146
+ self.raw_answers = []
147
+
148
+ def ask(self, text, conv):
149
+ if len(conv.messages) > 0 and conv.messages[-1][0] == conv.roles[0] \
150
+ and conv.messages[-1][1][-6:] == '</Img>': # last message is image.
151
+ conv.messages[-1][1] = ' '.join([conv.messages[-1][1], text])
152
+ else:
153
+ conv.append_message(conv.roles[0], text)
154
+
155
+ def answer(self, conv, img_list, max_new_tokens=300, num_beams=1, min_length=1, top_p=0.9,
156
+ repetition_penalty=1.0, length_penalty=1, temperature=1.0, max_length=2000):
157
+ conv.append_message(conv.roles[1], None)
158
+ embs = self.get_context_emb(conv, img_list)
159
+
160
+ current_max_len = embs.shape[1] + max_new_tokens
161
+ if current_max_len - max_length > 0:
162
+ print('Warning: The number of tokens in current conversation exceeds the max length. '
163
+ 'The model will not see the contexts outside the range.')
164
+ begin_idx = max(0, current_max_len - max_length)
165
+
166
+ embs = embs[:, begin_idx:]
167
+
168
+ outputs = self.model.llama_model.generate(
169
+ inputs_embeds=embs,
170
+ max_new_tokens=max_new_tokens,
171
+ stopping_criteria=self.stopping_criteria,
172
+ num_beams=num_beams,
173
+ min_length=min_length,
174
+ top_p=top_p,
175
+ repetition_penalty=repetition_penalty,
176
+ length_penalty=length_penalty,
177
+ temperature=temperature,
178
+ do_sample=False,
179
+ )
180
+ output_token = outputs[0]
181
+ if output_token[0] == 0:
182
+ output_token = output_token[1:]
183
+ output_text = self.model.llama_tokenizer.decode(output_token, add_special_tokens=False)
184
+ self.raw_answers.append(output_text)
185
+ output_text = output_text.split('</s>')[0] # remove the stop sign '###'
186
+ output_text = output_text.replace("<s>", "")
187
+ output_text = output_text.split(r'[/INST]')[-1].strip()
188
+ self.conv.messages[-1][1] = output_text
189
+ return output_text, output_token.cpu().numpy()
190
+
191
+ def upload_img(self, image):
192
+ if isinstance(image, str): # is a image path
193
+ raw_image = Image.open(image).convert('RGB')
194
+ image = self.vis_processor(raw_image).unsqueeze(0).to(self.device)
195
+ elif isinstance(image, Image.Image):
196
+ raw_image = image
197
+ image = self.vis_processor(raw_image).unsqueeze(0).to(self.device)
198
+ elif isinstance(image, torch.Tensor):
199
+ if len(image.shape) == 3:
200
+ image = image.unsqueeze(0)
201
+ image = image.to(self.device)
202
+
203
+ image_emb, _ = self.model.encode_img(image)
204
+ self.img_list.append(image_emb)
205
+ self.conv.append_message(self.conv.roles[0], "<Img><ImageHere></Img>")
206
+ msg = "Received."
207
+ # self.conv.append_message(self.conv.roles[1], msg)
208
+ return msg
209
+
210
+ def get_context_emb(self, conv, img_list):
211
+ prompt = conv.get_prompt()
212
+ prompt_segs = prompt.split('<ImageHere>')
213
+ assert len(prompt_segs) == len(img_list) + 1, "Unmatched numbers of image placeholders and images."
214
+ seg_tokens = [
215
+ self.model.llama_tokenizer(
216
+ seg, return_tensors="pt", add_special_tokens=i == 0).to(self.device).input_ids
217
+ # only add bos to the first seg
218
+ for i, seg in enumerate(prompt_segs)
219
+ ]
220
+
221
+ seg_embs = [self.model.embed_tokens(seg_t) for seg_t in seg_tokens]
222
+ mixed_embs = [emb for pair in zip(seg_embs[:-1], img_list) for emb in pair] + [seg_embs[-1]]
223
+ mixed_embs = torch.cat(mixed_embs, dim=1)
224
+ return mixed_embs
dist_utils.py ADDED
@@ -0,0 +1,146 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Copyright (c) 2022, salesforce.com, inc.
3
+ All rights reserved.
4
+ SPDX-License-Identifier: BSD-3-Clause
5
+ For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+
8
+ import datetime
9
+ import functools
10
+ import os
11
+
12
+ import torch
13
+ import torch.distributed as dist
14
+ import timm.models.hub as timm_hub
15
+
16
+
17
+ def setup_for_distributed(is_master):
18
+ """
19
+ This function disables printing when not in master process
20
+ """
21
+ import builtins as __builtin__
22
+
23
+ builtin_print = __builtin__.print
24
+
25
+ def print(*args, **kwargs):
26
+ force = kwargs.pop("force", False)
27
+ if is_master or force:
28
+ builtin_print(*args, **kwargs)
29
+
30
+ __builtin__.print = print
31
+
32
+
33
+ def is_dist_avail_and_initialized():
34
+ if not dist.is_available():
35
+ return False
36
+ if not dist.is_initialized():
37
+ return False
38
+ return True
39
+
40
+
41
+ def get_world_size():
42
+ if not is_dist_avail_and_initialized():
43
+ return 1
44
+ return dist.get_world_size()
45
+
46
+
47
+ def get_rank():
48
+ if not is_dist_avail_and_initialized():
49
+ return 0
50
+ return dist.get_rank()
51
+
52
+
53
+ def is_main_process():
54
+ return get_rank() == 0
55
+
56
+
57
+ def init_distributed_mode(args):
58
+ if args.distributed is False:
59
+ print("Not using distributed mode")
60
+ args.rank = 0
61
+ return
62
+
63
+ if 'LOCAL_RANK' not in os.environ:
64
+ os.environ['LOCAL_RANK'] = str(args.local_rank)
65
+
66
+ if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
67
+ args.rank = int(os.environ["RANK"])
68
+ args.world_size = int(os.environ["WORLD_SIZE"])
69
+ args.gpu = int(os.environ["LOCAL_RANK"])
70
+ elif "SLURM_PROCID" in os.environ:
71
+ args.rank = int(os.environ["SLURM_PROCID"])
72
+ args.gpu = args.rank % torch.cuda.device_count()
73
+ else:
74
+ print("Not using distributed mode")
75
+ args.distributed = False
76
+ args.rank = 0
77
+ return
78
+
79
+ args.distributed = True
80
+
81
+ torch.cuda.set_device(args.gpu)
82
+ args.dist_backend = "nccl"
83
+ print(
84
+ "| distributed init (rank {}, world {}): {}".format(
85
+ args.rank, args.world_size, args.dist_url
86
+ ),
87
+ flush=True,
88
+ )
89
+ torch.distributed.init_process_group(
90
+ backend=args.dist_backend,
91
+ init_method=args.dist_url,
92
+ world_size=args.world_size,
93
+ rank=args.rank,
94
+ timeout=datetime.timedelta(
95
+ days=365
96
+ ), # allow auto-downloading and de-compressing
97
+ )
98
+ torch.distributed.barrier()
99
+ setup_for_distributed(args.rank == 0)
100
+
101
+
102
+ def get_dist_info():
103
+ if torch.__version__ < "1.0":
104
+ initialized = dist._initialized
105
+ else:
106
+ initialized = dist.is_initialized()
107
+ if initialized:
108
+ rank = dist.get_rank()
109
+ world_size = dist.get_world_size()
110
+ else: # non-distributed training
111
+ rank = 0
112
+ world_size = 1
113
+ return rank, world_size
114
+
115
+
116
+ def main_process(func):
117
+ @functools.wraps(func)
118
+ def wrapper(*args, **kwargs):
119
+ rank, _ = get_dist_info()
120
+ if rank == 0:
121
+ return func(*args, **kwargs)
122
+
123
+ return wrapper
124
+
125
+
126
+ def download_cached_file(url, check_hash=True, progress=False):
127
+ """
128
+ Download a file from a URL and cache it locally. If the file already exists, it is not downloaded again.
129
+ If distributed, only the main process downloads the file, and the other processes wait for the file to be downloaded.
130
+ """
131
+
132
+ def get_cached_file_path():
133
+ # a hack to sync the file path across processes
134
+ parts = torch.hub.urlparse(url)
135
+ filename = os.path.basename(parts.path)
136
+ cached_file = os.path.join(timm_hub.get_cache_dir(), filename)
137
+
138
+ return cached_file
139
+
140
+ if is_main_process():
141
+ timm_hub.download_cached_file(url, check_hash, progress)
142
+
143
+ if is_dist_avail_and_initialized():
144
+ dist.barrier()
145
+
146
+ return get_cached_file_path()
eva_vit.py ADDED
@@ -0,0 +1,443 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Based on EVA, BEIT, timm and DeiT code bases
2
+ # https://github.com/baaivision/EVA
3
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm
4
+ # https://github.com/microsoft/unilm/tree/master/beit
5
+ # https://github.com/facebookresearch/deit/
6
+ # https://github.com/facebookresearch/dino
7
+ # --------------------------------------------------------'
8
+ import math
9
+ from functools import partial
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+ import torch.nn.functional as F
14
+ import torch.utils.checkpoint as checkpoint
15
+ from timm.models.layers import drop_path, to_2tuple, trunc_normal_
16
+ from timm.models.registry import register_model
17
+
18
+ from minigpt4_video.dist_utils import download_cached_file
19
+
20
+ def _cfg(url='', **kwargs):
21
+ return {
22
+ 'url': url,
23
+ 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
24
+ 'crop_pct': .9, 'interpolation': 'bicubic',
25
+ 'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5),
26
+ **kwargs
27
+ }
28
+
29
+
30
+ class DropPath(nn.Module):
31
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
32
+ """
33
+ def __init__(self, drop_prob=None):
34
+ super(DropPath, self).__init__()
35
+ self.drop_prob = drop_prob
36
+
37
+ def forward(self, x):
38
+ return drop_path(x, self.drop_prob, self.training)
39
+
40
+ def extra_repr(self) -> str:
41
+ return 'p={}'.format(self.drop_prob)
42
+
43
+
44
+ class Mlp(nn.Module):
45
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
46
+ super().__init__()
47
+ out_features = out_features or in_features
48
+ hidden_features = hidden_features or in_features
49
+ self.fc1 = nn.Linear(in_features, hidden_features)
50
+ self.act = act_layer()
51
+ self.fc2 = nn.Linear(hidden_features, out_features)
52
+ self.drop = nn.Dropout(drop)
53
+
54
+ def forward(self, x):
55
+ x = self.fc1(x)
56
+ x = self.act(x)
57
+ # x = self.drop(x)
58
+ # commit this for the orignal BERT implement
59
+ x = self.fc2(x)
60
+ x = self.drop(x)
61
+ return x
62
+
63
+
64
+ class Attention(nn.Module):
65
+ def __init__(
66
+ self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.,
67
+ proj_drop=0., window_size=None, attn_head_dim=None):
68
+ super().__init__()
69
+ self.num_heads = num_heads
70
+ head_dim = dim // num_heads
71
+ if attn_head_dim is not None:
72
+ head_dim = attn_head_dim
73
+ all_head_dim = head_dim * self.num_heads
74
+ self.scale = qk_scale or head_dim ** -0.5
75
+
76
+ self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
77
+ if qkv_bias:
78
+ self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
79
+ self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
80
+ else:
81
+ self.q_bias = None
82
+ self.v_bias = None
83
+
84
+ if window_size:
85
+ self.window_size = window_size
86
+ self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
87
+ self.relative_position_bias_table = nn.Parameter(
88
+ torch.zeros(self.num_relative_distance, num_heads)) # 2*Wh-1 * 2*Ww-1, nH
89
+ # cls to token & token 2 cls & cls to cls
90
+
91
+ # get pair-wise relative position index for each token inside the window
92
+ coords_h = torch.arange(window_size[0])
93
+ coords_w = torch.arange(window_size[1])
94
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
95
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
96
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
97
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
98
+ relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0
99
+ relative_coords[:, :, 1] += window_size[1] - 1
100
+ relative_coords[:, :, 0] *= 2 * window_size[1] - 1
101
+ relative_position_index = \
102
+ torch.zeros(size=(window_size[0] * window_size[1] + 1, ) * 2, dtype=relative_coords.dtype)
103
+ relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
104
+ relative_position_index[0, 0:] = self.num_relative_distance - 3
105
+ relative_position_index[0:, 0] = self.num_relative_distance - 2
106
+ relative_position_index[0, 0] = self.num_relative_distance - 1
107
+
108
+ self.register_buffer("relative_position_index", relative_position_index)
109
+ else:
110
+ self.window_size = None
111
+ self.relative_position_bias_table = None
112
+ self.relative_position_index = None
113
+
114
+ self.attn_drop = nn.Dropout(attn_drop)
115
+ self.proj = nn.Linear(all_head_dim, dim)
116
+ self.proj_drop = nn.Dropout(proj_drop)
117
+
118
+ def forward(self, x, rel_pos_bias=None):
119
+ B, N, C = x.shape
120
+ qkv_bias = None
121
+ if self.q_bias is not None:
122
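+ # BEiT-style qkv bias: q and v biases are learnable while k uses a fixed zero bias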
+ qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))
123
+ # qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
124
+ qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
125
+ qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
126
+ q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
127
+
128
+ q = q * self.scale
129
+ attn = (q @ k.transpose(-2, -1))
130
+
131
+ if self.relative_position_bias_table is not None:
132
+ relative_position_bias = \
133
+ self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
134
+ self.window_size[0] * self.window_size[1] + 1,
135
+ self.window_size[0] * self.window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH
136
+ relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
137
+ attn = attn + relative_position_bias.unsqueeze(0)
138
+
139
+ if rel_pos_bias is not None:
140
+ attn = attn + rel_pos_bias
141
+
142
+ attn = attn.softmax(dim=-1)
143
+ attn = self.attn_drop(attn)
144
+
145
+ x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
146
+ x = self.proj(x)
147
+ x = self.proj_drop(x)
148
+ return x
149
+
150
+
151
+ class Block(nn.Module):
152
+
153
+ def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
154
+ drop_path=0., init_values=None, act_layer=nn.GELU, norm_layer=nn.LayerNorm,
155
+ window_size=None, attn_head_dim=None):
156
+ super().__init__()
157
+ self.norm1 = norm_layer(dim)
158
+ self.attn = Attention(
159
+ dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
160
+ attn_drop=attn_drop, proj_drop=drop, window_size=window_size, attn_head_dim=attn_head_dim)
161
+ # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
162
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
163
+ self.norm2 = norm_layer(dim)
164
+ mlp_hidden_dim = int(dim * mlp_ratio)
165
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
166
+
167
+ if init_values is not None and init_values > 0:
168
+ self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True)
169
+ self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True)
170
+ else:
171
+ self.gamma_1, self.gamma_2 = None, None
172
+
173
+ def forward(self, x, rel_pos_bias=None):
174
+ if self.gamma_1 is None:
175
+ x = x + self.drop_path(self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias))
176
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
177
+ else:
178
+ x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias))
179
+ x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
180
+ return x
181
+
182
+
183
+ class PatchEmbed(nn.Module):
184
+ """ Image to Patch Embedding
185
+ """
186
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
187
+ super().__init__()
188
+ img_size = to_2tuple(img_size)
189
+ patch_size = to_2tuple(patch_size)
190
+ num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
191
+ self.patch_shape = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
192
+ self.img_size = img_size
193
+ self.patch_size = patch_size
194
+ self.num_patches = num_patches
195
+
196
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
197
+
198
+ def forward(self, x, **kwargs):
199
+ B, C, H, W = x.shape
200
+ # FIXME look at relaxing size constraints
201
+ assert H == self.img_size[0] and W == self.img_size[1], \
202
+ f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
203
+ x = self.proj(x).flatten(2).transpose(1, 2)
204
+ return x
205
+
206
+
207
+ class RelativePositionBias(nn.Module):
208
+
209
+ def __init__(self, window_size, num_heads):
210
+ super().__init__()
211
+ self.window_size = window_size
212
+ self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
213
+ self.relative_position_bias_table = nn.Parameter(
214
+ torch.zeros(self.num_relative_distance, num_heads)) # 2*Wh-1 * 2*Ww-1, nH
215
+ # cls to token, token to cls, and cls to cls
216
+
217
+ # get pair-wise relative position index for each token inside the window
218
+ coords_h = torch.arange(window_size[0])
219
+ coords_w = torch.arange(window_size[1])
220
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
221
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
222
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
223
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
224
+ relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0
225
+ relative_coords[:, :, 1] += window_size[1] - 1
226
+ relative_coords[:, :, 0] *= 2 * window_size[1] - 1
227
+ relative_position_index = \
228
+ torch.zeros(size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype)
229
+ relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
230
+ relative_position_index[0, 0:] = self.num_relative_distance - 3
231
+ relative_position_index[0:, 0] = self.num_relative_distance - 2
232
+ relative_position_index[0, 0] = self.num_relative_distance - 1
233
+
234
+ self.register_buffer("relative_position_index", relative_position_index)
235
+
236
+ # trunc_normal_(self.relative_position_bias_table, std=.02)
237
+
238
+ def forward(self):
239
+ relative_position_bias = \
240
+ self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
241
+ self.window_size[0] * self.window_size[1] + 1,
242
+ self.window_size[0] * self.window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH
243
+ return relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
244
+
245
+
246
+ class VisionTransformer(nn.Module):
247
+ """ Vision Transformer with support for patch or hybrid CNN input stage
248
+ """
249
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
250
+ num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
251
+ drop_path_rate=0., norm_layer=nn.LayerNorm, init_values=None,
252
+ use_abs_pos_emb=True, use_rel_pos_bias=False, use_shared_rel_pos_bias=False,
253
+ use_mean_pooling=True, init_scale=0.001, use_checkpoint=False):
254
+ super().__init__()
255
+ self.image_size = img_size
256
+ self.num_classes = num_classes
257
+ self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
258
+
259
+ self.patch_embed = PatchEmbed(
260
+ img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
261
+ num_patches = self.patch_embed.num_patches
262
+
263
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
264
+ if use_abs_pos_emb:
265
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
266
+ else:
267
+ self.pos_embed = None
268
+ self.pos_drop = nn.Dropout(p=drop_rate)
269
+
270
+ if use_shared_rel_pos_bias:
271
+ self.rel_pos_bias = RelativePositionBias(window_size=self.patch_embed.patch_shape, num_heads=num_heads)
272
+ else:
273
+ self.rel_pos_bias = None
274
+ self.use_checkpoint = use_checkpoint
275
+
276
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
277
+ self.use_rel_pos_bias = use_rel_pos_bias
278
+ self.blocks = nn.ModuleList([
279
+ Block(
280
+ dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
281
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
282
+ init_values=init_values, window_size=self.patch_embed.patch_shape if use_rel_pos_bias else None)
283
+ for i in range(depth)])
284
+ # self.norm = nn.Identity() if use_mean_pooling else norm_layer(embed_dim)
285
+ # self.fc_norm = norm_layer(embed_dim) if use_mean_pooling else None
286
+ # self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
287
+
288
+ if self.pos_embed is not None:
289
+ trunc_normal_(self.pos_embed, std=.02)
290
+ trunc_normal_(self.cls_token, std=.02)
291
+ # trunc_normal_(self.mask_token, std=.02)
292
+ # if isinstance(self.head, nn.Linear):
293
+ # trunc_normal_(self.head.weight, std=.02)
294
+ self.apply(self._init_weights)
295
+ self.fix_init_weight()
296
+ # if isinstance(self.head, nn.Linear):
297
+ # self.head.weight.data.mul_(init_scale)
298
+ # self.head.bias.data.mul_(init_scale)
299
+
300
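+ # rescale each block's attention and MLP output-projection weights by 1/sqrt(2 * layer_id) to keep residual branches stable in deep networks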
+ def fix_init_weight(self):
301
+ def rescale(param, layer_id):
302
+ param.div_(math.sqrt(2.0 * layer_id))
303
+
304
+ for layer_id, layer in enumerate(self.blocks):
305
+ rescale(layer.attn.proj.weight.data, layer_id + 1)
306
+ rescale(layer.mlp.fc2.weight.data, layer_id + 1)
307
+
308
+ def _init_weights(self, m):
309
+ if isinstance(m, nn.Linear):
310
+ trunc_normal_(m.weight, std=.02)
311
+ if isinstance(m, nn.Linear) and m.bias is not None:
312
+ nn.init.constant_(m.bias, 0)
313
+ elif isinstance(m, nn.LayerNorm):
314
+ nn.init.constant_(m.bias, 0)
315
+ nn.init.constant_(m.weight, 1.0)
316
+
317
+ def get_classifier(self):
318
+ return self.head
319
+
320
+ def reset_classifier(self, num_classes, global_pool=''):
321
+ self.num_classes = num_classes
322
+ self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
323
+
324
+ def forward_features(self, x):
325
+ x = self.patch_embed(x)
326
+ batch_size, seq_len, _ = x.size()
327
+
328
+ cls_tokens = self.cls_token.expand(batch_size, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
329
+ x = torch.cat((cls_tokens, x), dim=1)
330
+ if self.pos_embed is not None:
331
+ x = x + self.pos_embed
332
+ x = self.pos_drop(x)
333
+
334
+ rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None
335
+ for blk in self.blocks:
336
+ if self.use_checkpoint:
337
+ x = checkpoint.checkpoint(blk, x, rel_pos_bias)
338
+ else:
339
+ x = blk(x, rel_pos_bias)
340
+ return x
341
+ # x = self.norm(x)
342
+
343
+ # if self.fc_norm is not None:
344
+ # t = x[:, 1:, :]
345
+ # return self.fc_norm(t.mean(1))
346
+ # else:
347
+ # return x[:, 0]
348
+
349
+ def forward(self, x):
350
+ x = self.forward_features(x)
351
+ # x = self.head(x)
352
+ return x
353
+
354
+ def get_intermediate_layers(self, x):
355
+ x = self.patch_embed(x)
356
+ batch_size, seq_len, _ = x.size()
357
+
358
+ cls_tokens = self.cls_token.expand(batch_size, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
359
+ x = torch.cat((cls_tokens, x), dim=1)
360
+ if self.pos_embed is not None:
361
+ x = x + self.pos_embed
362
+ x = self.pos_drop(x)
363
+
364
+ features = []
365
+ rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None
366
+ for blk in self.blocks:
367
+ x = blk(x, rel_pos_bias)
368
+ features.append(x)
369
+
370
+ return features
371
+
372
+
373
+ def interpolate_pos_embed(model, checkpoint_model):
374
+ if 'pos_embed' in checkpoint_model:
375
+ pos_embed_checkpoint = checkpoint_model['pos_embed'].float()
376
+ embedding_size = pos_embed_checkpoint.shape[-1]
377
+ num_patches = model.patch_embed.num_patches
378
+ num_extra_tokens = model.pos_embed.shape[-2] - num_patches
379
+ # height (== width) for the checkpoint position embedding
380
+ orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
381
+ # height (== width) for the new position embedding
382
+ new_size = int(num_patches ** 0.5)
383
+ # class_token and dist_token are kept unchanged
384
+ if orig_size != new_size:
385
+ print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
386
+ extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
387
+ # only the position tokens are interpolated
388
+ pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
389
+ pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
390
+ pos_tokens = torch.nn.functional.interpolate(
391
+ pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
392
+ pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
393
+ new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
394
+ checkpoint_model['pos_embed'] = new_pos_embed
395
+
396
+
397
+ def convert_weights_to_fp16(model: nn.Module):
398
+ """Convert applicable model parameters to fp16"""
399
+
400
+ def _convert_weights_to_fp16(l):
401
+ if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
402
+ l.weight.data = l.weight.data.half()
403
+ if l.bias is not None:
404
+ l.bias.data = l.bias.data.half()
405
+
406
+ # if isinstance(l, (nn.MultiheadAttention, Attention)):
407
+ # for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
408
+ # tensor = getattr(l, attr)
409
+ # if tensor is not None:
410
+ # tensor.data = tensor.data.half()
411
+
412
+ model.apply(_convert_weights_to_fp16)
413
+
414
+
415
+ def create_eva_vit_g(img_size=224,drop_path_rate=0.4,use_checkpoint=False,precision="fp16"):
416
+ model = VisionTransformer(
417
+ img_size=img_size,
418
+ patch_size=14,
419
+ use_mean_pooling=False,
420
+ embed_dim=1408,
421
+ depth=39,
422
+ # depth = 37,
423
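+ # 1408 // 88 = 16 attention heads of dim 88; mlp_ratio=4.3637 gives an MLP hidden size of int(1408 * 4.3637) = 6144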
+ num_heads=1408//88,
424
+ mlp_ratio=4.3637,
425
+ qkv_bias=True,
426
+ drop_path_rate=drop_path_rate,
427
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
428
+ use_checkpoint=use_checkpoint,
429
+ )
430
+ url = "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/eva_vit_g.pth"
431
+ cached_file = download_cached_file(
432
+ url, check_hash=False, progress=True
433
+ )
434
+ state_dict = torch.load(cached_file, map_location="cpu")
435
+ interpolate_pos_embed(model,state_dict)
436
+
437
+ incompatible_keys = model.load_state_dict(state_dict, strict=False)
438
+ # print(incompatible_keys)
439
+
440
+ if precision == "fp16":
441
+ # model.to("cuda")
442
+ convert_weights_to_fp16(model)
443
+ return model
gradcam.py ADDED
@@ -0,0 +1,24 @@
1
+ import numpy as np
2
+ from matplotlib import pyplot as plt
3
+ from scipy.ndimage import filters
4
+ from skimage import transform as skimage_transform
5
+
6
+
7
+ def getAttMap(img, attMap, blur=True, overlap=True):
8
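+ # normalize the attention map to [0, 1], resize it to the image resolution, optionally blur it, and overlay it on the image with the jet colormap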
+ attMap -= attMap.min()
9
+ if attMap.max() > 0:
10
+ attMap /= attMap.max()
11
+ attMap = skimage_transform.resize(attMap, (img.shape[:2]), order=3, mode="constant")
12
+ if blur:
13
+ attMap = filters.gaussian_filter(attMap, 0.02 * max(img.shape[:2]))
14
+ attMap -= attMap.min()
15
+ attMap /= attMap.max()
16
+ cmap = plt.get_cmap("jet")
17
+ attMapV = cmap(attMap)
18
+ attMapV = np.delete(attMapV, 3, 2)
19
+ if overlap:
20
+ attMap = (
21
+ 1 * (1 - attMap**0.7).reshape(attMap.shape + (1,)) * img
22
+ + (attMap**0.7).reshape(attMap.shape + (1,)) * attMapV
23
+ )
24
+ return attMap
logger.py ADDED
@@ -0,0 +1,195 @@
1
+ """
2
+ Copyright (c) 2022, salesforce.com, inc.
3
+ All rights reserved.
4
+ SPDX-License-Identifier: BSD-3-Clause
5
+ For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+
8
+ import datetime
9
+ import logging
10
+ import time
11
+ from collections import defaultdict, deque
12
+
13
+ import torch
14
+ import torch.distributed as dist
15
+
16
+ from minigpt4_video import dist_utils
17
+
18
+
19
+ class SmoothedValue(object):
20
+ """Track a series of values and provide access to smoothed values over a
21
+ window or the global series average.
22
+ """
23
+
24
+ def __init__(self, window_size=20, fmt=None):
25
+ if fmt is None:
26
+ fmt = "{median:.4f} ({global_avg:.4f})"
27
+ self.deque = deque(maxlen=window_size)
28
+ self.total = 0.0
29
+ self.count = 0
30
+ self.fmt = fmt
31
+
32
+ def update(self, value, n=1):
33
+ self.deque.append(value)
34
+ self.count += n
35
+ self.total += value * n
36
+
37
+ def synchronize_between_processes(self):
38
+ """
39
+ Warning: does not synchronize the deque!
40
+ """
41
+ if not dist_utils.is_dist_avail_and_initialized():
42
+ return
43
+ t = torch.tensor([self.count, self.total], dtype=torch.float64, device="cuda")
44
+ dist.barrier()
45
+ dist.all_reduce(t)
46
+ t = t.tolist()
47
+ self.count = int(t[0])
48
+ self.total = t[1]
49
+
50
+ @property
51
+ def median(self):
52
+ d = torch.tensor(list(self.deque))
53
+ return d.median().item()
54
+
55
+ @property
56
+ def avg(self):
57
+ d = torch.tensor(list(self.deque), dtype=torch.float32)
58
+ return d.mean().item()
59
+
60
+ @property
61
+ def global_avg(self):
62
+ return self.total / self.count
63
+
64
+ @property
65
+ def max(self):
66
+ return max(self.deque)
67
+
68
+ @property
69
+ def value(self):
70
+ return self.deque[-1]
71
+
72
+ def __str__(self):
73
+ return self.fmt.format(
74
+ median=self.median,
75
+ avg=self.avg,
76
+ global_avg=self.global_avg,
77
+ max=self.max,
78
+ value=self.value,
79
+ )
80
+
81
+
82
+ class MetricLogger(object):
83
+ def __init__(self, delimiter="\t"):
84
+ self.meters = defaultdict(SmoothedValue)
85
+ self.delimiter = delimiter
86
+
87
+ def update(self, **kwargs):
88
+ for k, v in kwargs.items():
89
+ if isinstance(v, torch.Tensor):
90
+ v = v.item()
91
+ assert isinstance(v, (float, int))
92
+ self.meters[k].update(v)
93
+
94
+ def __getattr__(self, attr):
95
+ if attr in self.meters:
96
+ return self.meters[attr]
97
+ if attr in self.__dict__:
98
+ return self.__dict__[attr]
99
+ raise AttributeError(
100
+ "'{}' object has no attribute '{}'".format(type(self).__name__, attr)
101
+ )
102
+
103
+ def __str__(self):
104
+ loss_str = []
105
+ for name, meter in self.meters.items():
106
+ loss_str.append("{}: {}".format(name, str(meter)))
107
+ return self.delimiter.join(loss_str)
108
+
109
+ def global_avg(self):
110
+ loss_str = []
111
+ for name, meter in self.meters.items():
112
+ loss_str.append("{}: {:.4f}".format(name, meter.global_avg))
113
+ return self.delimiter.join(loss_str)
114
+
115
+ def synchronize_between_processes(self):
116
+ for meter in self.meters.values():
117
+ meter.synchronize_between_processes()
118
+
119
+ def add_meter(self, name, meter):
120
+ self.meters[name] = meter
121
+
122
+ def log_every(self, iterable, print_freq, header=None):
123
+ i = 0
124
+ if not header:
125
+ header = ""
126
+ start_time = time.time()
127
+ end = time.time()
128
+ iter_time = SmoothedValue(fmt="{avg:.4f}")
129
+ data_time = SmoothedValue(fmt="{avg:.4f}")
130
+ space_fmt = ":" + str(len(str(len(iterable)))) + "d"
131
+ log_msg = [
132
+ header,
133
+ "[{0" + space_fmt + "}/{1}]",
134
+ "eta: {eta}",
135
+ "{meters}",
136
+ "time: {time}",
137
+ "data: {data}",
138
+ ]
139
+ if torch.cuda.is_available():
140
+ log_msg.append("max mem: {memory:.0f}")
141
+ log_msg = self.delimiter.join(log_msg)
142
+ MB = 1024.0 * 1024.0
143
+ for obj in iterable:
144
+ data_time.update(time.time() - end)
145
+ yield obj
146
+ iter_time.update(time.time() - end)
147
+ if i % print_freq == 0 or i == len(iterable) - 1:
148
+ eta_seconds = iter_time.global_avg * (len(iterable) - i)
149
+ eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
150
+ if torch.cuda.is_available():
151
+ print(
152
+ log_msg.format(
153
+ i,
154
+ len(iterable),
155
+ eta=eta_string,
156
+ meters=str(self),
157
+ time=str(iter_time),
158
+ data=str(data_time),
159
+ memory=torch.cuda.max_memory_allocated() / MB,
160
+ )
161
+ )
162
+ else:
163
+ print(
164
+ log_msg.format(
165
+ i,
166
+ len(iterable),
167
+ eta=eta_string,
168
+ meters=str(self),
169
+ time=str(iter_time),
170
+ data=str(data_time),
171
+ )
172
+ )
173
+ i += 1
174
+ end = time.time()
175
+ total_time = time.time() - start_time
176
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
177
+ print(
178
+ "{} Total time: {} ({:.4f} s / it)".format(
179
+ header, total_time_str, total_time / len(iterable)
180
+ )
181
+ )
182
+
183
+
184
+ class AttrDict(dict):
185
+ def __init__(self, *args, **kwargs):
186
+ super(AttrDict, self).__init__(*args, **kwargs)
187
+ self.__dict__ = self
188
+
189
+
190
+ def setup_logger():
191
+ logging.basicConfig(
192
+ level=logging.INFO if dist_utils.is_main_process() else logging.WARN,
193
+ format="%(asctime)s [%(levelname)s] %(message)s",
194
+ handlers=[logging.StreamHandler()],
195
+ )
mini_gpt4v.py ADDED
@@ -0,0 +1,709 @@
1
+ import logging
2
+ import random
3
+
4
+ import torch
5
+ from torch.cuda.amp import autocast as autocast
6
+ import torch.nn as nn
7
+
8
+ from minigpt4.common.registry import registry
9
+ from minigpt4.models.blip2 import Blip2Base, disabled_train
10
+ from minigpt4.models.modeling_llama_v2 import LlamaForCausalLM
11
+ from minigpt4.conversation.conversation import Conversation, SeparatorStyle, StoppingCriteriaList, StoppingCriteriaSub
12
+
13
+ from transformers import LlamaTokenizer, CodeLlamaTokenizer, BitsAndBytesConfig
14
+
15
+ from peft import (
16
+ LoraConfig,
17
+ get_peft_model,
18
+ prepare_model_for_kbit_training
19
+ )
20
+ import time
21
+ import numpy as np
22
+
23
+ from minigpt4.models import policies
24
+
25
+
26
+ @registry.register_model("mini_gpt4v")
27
+ class MiniGPT4v(Blip2Base):
28
+ """
29
+ BLIP2 GPT-LLAMA model.
30
+ """
31
+
32
+ PRETRAINED_MODEL_CONFIG_DICT = {
33
+ "pretrain_vicuna": "configs/models/minigpt4.yaml",
34
+ }
35
+
36
+ def __init__(
37
+ self,
38
+ vit_model="eva_clip_g",
39
+ img_size=224,
40
+ drop_path_rate=0,
41
+ use_grad_checkpoint=False,
42
+ vit_precision="fp16",
43
+ freeze_vit=True,
44
+ llama_model="",
45
+ prompt_path="",
46
+ prompt_template="",
47
+ max_txt_len=32,
48
+ low_resource=False,  # use 8-bit loading and keep the ViT on CPU
49
+ end_sym='\n',
50
+ lora_r = 8,
51
+ lora_target_modules = ["q_proj","v_proj"],
52
+ lora_alpha=16,
53
+ # lora_r = 16,
54
+ # lora_target_modules = ["q_proj","v_proj","v_proj"],
55
+ lora_dropout= 0.05,
56
+ ckpt_path = "",
57
+ system_prompt= False,
58
+ chat_template=False,
59
+ token_pooling=True,
60
+ use_grad_checkpoint_llm=False,
61
+ max_context_len=3800,
62
+ remove_template = False,
63
+
64
+ ):
65
+ super().__init__()
66
+
67
+ self.tokenizer = self.init_tokenizer()
68
+ self.low_resource = low_resource
69
+ self.token_pooling = token_pooling
70
+ self.remove_template = remove_template
71
+
72
+ print("token pooling", self.token_pooling)
73
+
74
+
75
+ self.use_grad_checkpoint_llm = use_grad_checkpoint_llm
76
+ self.max_context_len = max_context_len
77
+ self.chat_template = chat_template
78
+
79
+ # print('Loading VIT')
80
+ # self.visual_encoder, self.ln_vision = self.init_vision_encoder(
81
+ # vit_model, img_size, drop_path_rate, use_grad_checkpoint, vit_precision
82
+ # )
83
+
84
+
85
+ print("vit precision", vit_precision)
86
+ self.visual_encoder, self.ln_vision = self.init_vision_encoder(
87
+ vit_model, 224, drop_path_rate, use_grad_checkpoint, vit_precision
88
+ )
89
+ for name, param in self.visual_encoder.named_parameters():
90
+ param.requires_grad = False
91
+ self.visual_encoder = self.visual_encoder.eval()
92
+ self.visual_encoder.train = disabled_train
93
+ for name, param in self.ln_vision.named_parameters():
94
+ param.requires_grad = False
95
+ self.ln_vision = self.ln_vision.eval()
96
+ self.ln_vision.train = disabled_train
97
+ logging.info("freeze vision encoder")
98
+ print("freeze the vision encoder")
99
+
100
+
101
+ print('Loading VIT Done')
102
+
103
+ # print("visual encoder shape", self.visual_encoder.pos_embed.shape)
104
+ # assert False
105
+
106
+ print('Loading LLAMA')
107
+
108
+
109
+ self.B_SYS, self.E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
110
+
111
+ if 'CodeLlama' in llama_model:
112
+ self.llama_tokenizer = CodeLlamaTokenizer.from_pretrained(llama_model, use_fast=False) #
113
+ self.llama_tokenizer.pad_token = "$$"
114
+ else:
115
+ self.llama_tokenizer = LlamaTokenizer.from_pretrained(llama_model, use_fast=False) #
116
+ self.llama_tokenizer.pad_token = "$$"
117
+
118
+ self.system_prompt = system_prompt
119
+
120
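+ # load the LLM in 4-bit NF4 with double quantization; matmuls are computed in bfloat16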
+ bnb_config = BitsAndBytesConfig(
121
+ load_in_4bit=True,
122
+ bnb_4bit_use_double_quant=True,
123
+ bnb_4bit_quant_type="nf4",
124
+ bnb_4bit_compute_dtype=torch.bfloat16
125
+ )
126
+
127
+
128
+
129
+ self.llama_model = LlamaForCausalLM.from_pretrained(
130
+ llama_model,
131
+ quantization_config=bnb_config,
132
+ device_map={"": 0}
133
+ )
134
+
135
+ # self.llama_model.gradient_checkpointing_enable()
136
+ self.llama_model = prepare_model_for_kbit_training(self.llama_model)
137
+
138
+ # self.llama_model.print_trainable_parameters()
139
+
140
+
141
+ print('Loading LLAMA Done')
142
+
143
+ self.merge_n = 3
144
+
145
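+ # 1408 is the EVA ViT-g feature dim; each group of merge_n**2 = 9 concatenated visual tokens is projected to one LLaMA token embedding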
+ self.llama_proj = nn.Linear(
146
+ 1408 * self.merge_n**2, self.llama_model.config.hidden_size
147
+ )
148
+
149
+ self.max_txt_len = max_txt_len
150
+ self.end_sym = end_sym
151
+
152
+ if prompt_path:
153
+ with open(prompt_path, 'r') as f:
154
+ raw_prompts = f.read().splitlines()
155
+ filtered_prompts = [raw_prompt for raw_prompt in raw_prompts if "<ImageHere>" in raw_prompt]
156
+ self.prompt_list = [prompt_template.format(p) for p in filtered_prompts]
157
+ print('Load {} training prompts'.format(len(self.prompt_list)))
158
+ print('Prompt Example \n{}'.format(random.choice(self.prompt_list)))
159
+ else:
160
+ self.prompt_list = []
161
+
162
+ def encode_img(self, image):
163
+ device = image.device
164
+ if len(image.shape) > 4:
165
+ image = image.reshape(-1, *image.shape[-3:])
166
+
167
+ bs, ch, w, h = image.shape
168
+ assert w % 224 == 0
169
+ bw = w // 224
170
+ assert h % 224 == 0
171
+ bh = h // 224
172
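+ # split the high-resolution input (width and height are multiples of 224) into a bw x bh grid of 224x224 tiles and encode each tile separately with the frozen ViT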
+ image_patches = image.view(bs, ch, bw, 224, bh, 224).permute(0, 2, 4, 1, 3, 5) # bs, bw, bh, ch, 224, 224
173
+ image_patches = image_patches.reshape(bs * bw * bh, ch, 224, 224)
174
+
175
+ with self.maybe_autocast():
176
+ image_patch_embeds = self.ln_vision(self.visual_encoder(image_patches)).to(device)
177
+
178
+ image_patch_embeds = image_patch_embeds[:,1:,:].reshape(bs, bw, bh, 16, 16, image_patch_embeds.shape[-1])
179
+ image_patch_embeds = image_patch_embeds.permute(0, 1, 3, 2, 4, 5) # bs, bw, 16, bh, 16, hs
180
+ image_embeds = image_patch_embeds.reshape(bs, bw * 16 * bh * 16, image_patch_embeds.shape[-1])
181
+
182
+ bs, pn, hs = image_embeds.shape
183
+
184
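+ # token pooling: concatenate every merge_n**2 = 9 consecutive visual tokens along the feature dim, shrinking the sequence length by 9x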
+ image_embeds = image_embeds.view(bs, int(pn/self.merge_n**2), int(hs*self.merge_n**2))
185
+
186
+ inputs_llama = self.llama_proj(image_embeds)
187
+ atts_llama = torch.ones(inputs_llama.size()[:-1], dtype=torch.long).to(image.device)
188
+ return inputs_llama, atts_llama
189
+
190
+ def get_context_emb(self, prompt, img_list):
191
+ img_device = img_list[0].device
192
+ prompt_segs = prompt.split('<ImageHere>')
193
+ assert len(prompt_segs) == len(img_list) + 1, "Unmatched numbers of image placeholders and images."
194
+ seg_tokens = [
195
+ self.llama_tokenizer(
196
+ seg, return_tensors="pt", add_special_tokens=i==0).to(img_device).input_ids # only add bos to the first seg
197
+ for i, seg in enumerate(prompt_segs)
198
+ ]
199
+
200
+ seg_embs = [self.embed_tokens(seg_t) for seg_t in seg_tokens]
201
+
202
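+ # interleave text-segment embeddings with image embeddings: seg_0, img_0, seg_1, img_1, ..., seg_last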
+ mixed_embs = [emb for pair in zip(seg_embs[:-1], img_list) for emb in pair] + [seg_embs[-1]]
203
+
204
+ mixed_embs = torch.cat(mixed_embs, dim=1)
205
+ return mixed_embs
206
+
207
+ def prompt_wrap(self, img_embeds, atts_img, prompts, lengths=None):
208
+ if prompts is None or len(prompts) == 0:
209
+ # no prompts are provided, so just return the original image embedding
210
+ return img_embeds, atts_img
211
+ elif img_embeds is None:
212
+ # a prompt is provided but there is no image embedding; return the right-padded prompt embedding
213
+ self.llama_tokenizer.padding_side = "right"
214
+ prompt_tokens = self.llama_tokenizer(
215
+ prompts,
216
+ return_tensors="pt",
217
+ padding="longest",
218
+ add_special_tokens=False
219
+ ).to(self.device)
220
+ prompt_embeds = self.embed_tokens(prompt_tokens.input_ids)
221
+ atts_prompt = prompt_tokens.attention_mask
222
+ return prompt_embeds, atts_prompt
223
+
224
+ else:
225
+ # return the right-padded multi-modal embedding
226
+ emb_lists = []
227
+
228
+ for idx, (each_img_embed, each_prompt) in enumerate(zip(img_embeds, prompts)):
229
+ pn = each_img_embed.shape[-2]
230
+ if lengths is not None:
231
+ each_img_embed = each_img_embed.reshape(-1, each_img_embed.shape[-1])
232
+ each_img_embed = each_img_embed[:lengths[idx] * pn]
233
+
234
+ p_segs = each_prompt.split('<ImageHere>')
235
+ interleave_emb = []
236
+ for idx, seg in enumerate(p_segs[:-1]):
237
+ p_tokens = self.llama_tokenizer(seg, return_tensors="pt", add_special_tokens=False).to(img_embeds.device)
238
+ p_embed = self.embed_tokens(p_tokens.input_ids)
239
+ interleave_emb.append(torch.cat([p_embed, each_img_embed[None][:, idx*pn:(idx+1)*pn]], dim=1))
240
+
241
+ wrapped_emb = torch.cat(interleave_emb, dim=1)
242
+ p_tokens = self.llama_tokenizer(p_segs[-1], return_tensors="pt", add_special_tokens=False).to(img_embeds.device)
243
+ p_embed = self.embed_tokens(p_tokens.input_ids)
244
+ wrapped_emb = torch.cat([wrapped_emb,p_embed], dim=1)
245
+ emb_lists.append(wrapped_emb)
246
+
247
+ emb_lens = [emb.shape[1] for emb in emb_lists]
248
+ pad_emb = self.embed_tokens(torch.tensor(self.llama_tokenizer.pad_token_id, device=img_embeds.device))
249
+
250
+ max_length = max(emb_lens) if max(emb_lens) < self.max_context_len else self.max_context_len
251
+ wrapped_embs = pad_emb.expand(len(emb_lens), max_length, -1).clone()
252
+ wrapped_atts = torch.zeros([len(emb_lens), max_length], dtype=torch.int, device=img_embeds.device)
253
+
254
+ for i, emb in enumerate(emb_lists):
255
+ length = emb_lens[i] if emb_lens[i] < self.max_context_len else self.max_context_len
256
+ wrapped_embs[i, :length] = emb[:, :length]
257
+ wrapped_atts[i, :length] = 1
258
+
259
+ return wrapped_embs, wrapped_atts
260
+
261
+ def concat_emb_input_output(self, input_embs, input_atts, output_embs, output_atts):
262
+ """
263
+ Concatenate the batched input embedding and batched output embedding together.
264
+ Both the input and the output embedding should be right padded.
265
+ """
266
+
267
+ input_lens = []
268
+ cat_embs = []
269
+ cat_atts = []
270
+
271
+ for i in range(input_embs.size(0)):
272
+ input_len = input_atts[i].sum()
273
+ input_lens.append(input_len)
274
+
275
+ cat_embs.append(
276
+ torch.cat([
277
+ input_embs[i][:input_len],
278
+ output_embs[i],
279
+ input_embs[i][input_len:]
280
+ ])
281
+ )
282
+ cat_atts.append(
283
+ torch.cat([
284
+ input_atts[i][:input_len],
285
+ output_atts[i],
286
+ input_atts[i][input_len:]
287
+ ])
288
+ )
289
+ # print('===================================')
290
+ # print('check input emb: ', input_embs[i][this_input_ones-2:this_input_ones])
291
+ # print('check pad emb: ', input_embs[i][this_input_ones:this_input_ones+2])
292
+ # print('check out emb: ', output_embs[i][:2])
293
+ # print('check out pad emb: ', output_embs[i][-2:])
294
+ # print('+++++++++++++++++++++++++++++++++++')
295
+ #
296
+ # print('check attn before: ', input_atts[i][:this_input_ones])
297
+ # print('check attn after: ', input_atts[i][this_input_ones:])
298
+ # print('check attn gt before: ', output_atts[i][:3])
299
+ # print('check attn gt after: ', output_atts[i][-3:])
300
+
301
+ cat_embs = torch.stack(cat_embs)
302
+ cat_atts = torch.stack(cat_atts)
303
+ return cat_embs, cat_atts, input_lens
304
+
305
+ def get_conv_emb(self, conv_q, conv_a, conv_img):
306
+ """concatenate conversation and make sure the model is only trained to regress the answer"""
307
+
308
+ regress_embs_list = []
309
+ targets_list = []
310
+
311
+ batch_size = len(conv_q)
312
+ for batch_idx in range(batch_size):
313
+ questions, answers = conv_q[batch_idx], conv_a[batch_idx]
314
+ assigned_imgs = conv_img[batch_idx]
315
+ questions = [self.prompt_wrap(
316
+ img_embeds=img,
317
+ atts_img=None,
318
+ prompts=[q],
319
+ lengths=[img.shape[1]] if img is not None else None) for q, img in zip(questions, assigned_imgs)]
320
+ q_embs = [emb for emb, _ in questions]
321
+
322
+ answers = [self.llama_tokenizer(a, return_tensors="pt", add_special_tokens=False).to(self.device) for a in answers]
323
+ cur_emb = []
324
+ cur_target = []
325
+ for i in range(len(questions)):
326
+ cur_emb.append(q_embs[i])
327
+ cur_target.append(torch.ones_like(q_embs[i][..., 0], dtype=torch.int) * -100)
328
+
329
+ cur_emb.append(self.embed_tokens(answers[i].input_ids))
330
+ cur_target.append(answers[i].input_ids)
331
+
332
+ cur_emb = torch.cat(cur_emb, dim=1)
333
+ cur_target = torch.cat(cur_target, dim=1)
334
+
335
+ regress_embs_list.append(cur_emb)
336
+ targets_list.append(cur_target)
337
+
338
+ max_len = min(max([target.shape[1] for target in targets_list]), self.max_txt_len)
339
+
340
+ regress_embeds = torch.zeros([batch_size, max_len, cur_emb.shape[-1]], device=self.device)
341
+ regress_attn = torch.zeros([batch_size, max_len], dtype=torch.int, device=self.device)
342
+ targets = torch.ones([batch_size, max_len], dtype=torch.long, device=self.device) * -100
343
+
344
+ for batch_idx in range(batch_size):
345
+ cur_len = regress_embs_list[batch_idx].shape[1]
346
+ regress_embeds[batch_idx, :cur_len] = regress_embs_list[batch_idx][0, :max_len]
347
+ regress_attn[batch_idx, :cur_len] = 1
348
+ targets[batch_idx, :cur_len] = targets_list[batch_idx][0, :max_len]
349
+
350
+ return regress_embeds, regress_attn, targets
351
+
352
+ def preparing_embedding(self, samples):
353
+ def remove_special_tokens(data):
354
+
355
+ # if "instruction_input" in data:
356
+ data = [instruct.replace(" [caption]","") for instruct in data]
357
+ data = [instruct.replace(" [vqa]","") for instruct in data]
358
+ data = [instruct.replace(" [grounding]","") for instruct in data]
359
+ data = [instruct.replace(" [identify]","") for instruct in data]
360
+ data = [instruct.replace(" [refer]","") for instruct in data]
361
+ return data
362
+
363
+ ### prepare input tokens
364
+ if 'image' in samples:
365
+ img_embeds, img_atts = self.encode_img(samples["image"])
366
+ else:
367
+ img_embeds = img_atts = None
368
+
369
+ if 'conv_q' in samples:
370
+ # handling conversation datasets
371
+ conv_q, conv_a = samples['conv_q'], samples['conv_a']
372
+
373
+ connect_sym = samples['connect_sym'][0]
374
+ conv_q = [q.split(connect_sym)for q in conv_q]
375
+ conv_a = [a.split(connect_sym) for a in conv_a]
376
+ conv_img = assign_imgs(conv_q, img_embeds)
377
+
378
+ if self.chat_template:
379
+ conv_q = [["[INST] " + item + "[/INST]" for item in items] for items in conv_q]
380
+
381
+ regress_embeds, regress_atts, part_targets = self.get_conv_emb(conv_q, conv_a, conv_img)
382
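+ # for conversation data everything is packed into the regression stream, so the conditioning part is an empty slice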
+ cond_embeds, cond_atts = regress_embeds[:, :0], regress_atts[:, :0]
383
+
384
+ else:
385
+ instruction = samples["instruction_input"] if "instruction_input" in samples else None
386
+
387
+ # print("instruction before", instruction)
388
+ if self.remove_template:
389
+ instruction = remove_special_tokens(instruction)
390
+ # print("instruction after", instruction)
391
+
392
+ if self.chat_template:
393
+ instruction = ["[INST] " + instruct + "[/INST]" for instruct in instruction]
394
+
395
+ if 'length' in samples:
396
+ # the input is a sequence of images (e.g., video frames)
397
+ bsz, pn, hs = img_embeds.shape
398
+ img_embeds = img_embeds.reshape(len(samples['image']), -1, pn, hs)
399
+ cond_embeds, cond_atts = self.prompt_wrap(img_embeds, img_atts, instruction, samples['length'])
400
+ else:
401
+ cond_embeds, cond_atts = self.prompt_wrap(img_embeds, img_atts, instruction)
402
+
403
+ ### prepare target tokens
404
+ self.llama_tokenizer.padding_side = "right"
405
+ text = [t + self.end_sym for t in samples["answer"]]
406
+
407
+ regress_tokens = self.llama_tokenizer(
408
+ text,
409
+ return_tensors="pt",
410
+ padding="longest",
411
+ truncation=True,
412
+ max_length=self.max_txt_len,
413
+ add_special_tokens=False
414
+ ).to(self.device)
415
+
416
+ regress_token_ids = regress_tokens.input_ids
417
+ regress_atts = regress_tokens.attention_mask
418
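+ # exclude padding positions from the loss by setting their labels to -100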
+ part_targets = regress_token_ids.masked_fill(
419
+ regress_token_ids == self.llama_tokenizer.pad_token_id, -100
420
+ )
421
+
422
+ regress_embeds = self.embed_tokens(regress_token_ids)
423
+
424
+ return cond_embeds, cond_atts, regress_embeds, regress_atts, part_targets
425
+
426
+ def forward(self, samples, reduction="mean"):
427
+ # prepare the embedding to condition and the embedding to regress
428
+ cond_embeds, cond_atts, regress_embeds, regress_atts, part_targets = \
429
+ self.preparing_embedding(samples)
430
+
431
+ # concat the embedding to condition and the embedding to regress
432
+ inputs_embeds, attention_mask, input_lens = \
433
+ self.concat_emb_input_output(cond_embeds, cond_atts, regress_embeds, regress_atts)
434
+
435
+ # get bos token embedding
436
+ bos = torch.ones_like(part_targets[:, :1]) * self.llama_tokenizer.bos_token_id
437
+ bos_embeds = self.embed_tokens(bos)
438
+ bos_atts = attention_mask[:, :1]
439
+
440
+ # add the bos token at the beginning
441
+ inputs_embeds = torch.cat([bos_embeds, inputs_embeds], dim=1)
442
+ attention_mask = torch.cat([bos_atts, attention_mask], dim=1)
443
+
444
+ # ensemble the final targets
445
+ targets = torch.ones([inputs_embeds.shape[0], inputs_embeds.shape[1]],
446
+ dtype=torch.long).to(self.device).fill_(-100)
447
+ for i, target in enumerate(part_targets):
448
+ targets[i, input_lens[i]+1:input_lens[i]+len(target)+1] = target # plus 1 for bos
449
+
450
+ with self.maybe_autocast():
451
+ outputs = self.llama_model(
452
+ inputs_embeds=inputs_embeds,
453
+ attention_mask=attention_mask,
454
+ return_dict=True,
455
+ labels=targets,
456
+ reduction=reduction
457
+ )
458
+ loss = outputs.loss
459
+
460
+ return {"loss": loss}
461
+
462
+ @torch.no_grad()
463
+ def generate(
464
+ self,
465
+ images,
466
+ texts,
467
+ use_nucleus_sampling=False,
468
+ num_beams=1,
469
+ max_new_tokens=20,
470
+ min_length=1,
471
+ top_p=0.9,
472
+ repetition_penalty=1,
473
+ length_penalty=1,
474
+ temperature=1,
475
+ do_sample=False,
476
+ stop_words_ids=[2],
477
+ lengths=None,
478
+ ):
479
+ '''
480
+ generation function for test-time use
481
+ '''
482
+
483
+ stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(
484
+ stops=[torch.tensor([i]).to(self.device) for i in stop_words_ids])])
485
+
486
+ img_embeds, atts_img = self.encode_img(images.to(self.device))
487
+ if lengths is not None:
488
+ image_lists = []
489
+ img_embeds = img_embeds.reshape(len(lengths), -1, img_embeds.shape[-2], img_embeds.shape[-1])
490
+ for idx, img_embed in enumerate(img_embeds):
491
+ image_lists.append([img_embed[i][None] for i in range(lengths[idx])])
492
+ else:
493
+ image_lists = [[image_emb[None]] for image_emb in img_embeds]
494
+ assert len(texts) == len(image_lists)
495
+ batch_embs = [self.get_context_emb(text, img_list) for text, img_list in zip(texts, image_lists)]
496
+
497
+ batch_size = len(batch_embs)
498
+ max_len = max([emb.shape[1] for emb in batch_embs])
499
+ emb_dim = batch_embs[0].shape[2]
500
+ dtype = batch_embs[0].dtype
501
+ device = batch_embs[0].device
502
+
503
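+ # left-pad the batch: each sample's embeddings are placed at the end so generation starts right after its last prompt token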
+ embs = torch.zeros([batch_size, max_len, emb_dim], dtype=dtype, device=device)
504
+ attn_mask = torch.zeros([batch_size, max_len], dtype=torch.int, device=device)
505
+ for i, emb in enumerate(batch_embs):
506
+ emb_len = emb.shape[1]
507
+ embs[i, -emb_len:] = emb[0]
508
+ attn_mask[i, -emb_len:] = 1
509
+
510
+ with self.maybe_autocast():
511
+ outputs = self.llama_model.generate(
512
+ inputs_embeds=embs,
513
+ attention_mask=attn_mask,
514
+ max_new_tokens=max_new_tokens,
515
+ num_beams=num_beams,
516
+ do_sample=do_sample,
517
+ # stopping_criteria=stopping_criteria,
518
+ )
519
+
520
+ answers = []
521
+ for output_token in outputs:
522
+ if output_token[0] == 0:
523
+ output_token = output_token[1:]
524
+ output_texts = self.llama_tokenizer.decode(output_token, skip_special_tokens=True)
525
+ output_texts = output_texts.split('</s>')[0]  # remove the stop token </s>
526
+ output_texts = output_texts.replace("<s>", "")
527
+ output_texts = output_texts.split(r'[/INST]')[-1].strip()
528
+ answers.append(output_texts)
529
+
530
+ return answers
531
+
532
+ @torch.no_grad()
533
+ def multi_select(self, images, texts, answers, num_cand=None):
534
+ all_losses = []
535
+ for answer in answers:
536
+ choice_samples = {
537
+ 'image': images,
538
+ 'instruction_input': texts,
539
+ 'answer': answer
540
+ }
541
+ loss = self.forward(choice_samples, reduction='none')['loss'].reshape(-1, 1)
542
+ all_losses.append(loss)
543
+ torch.cuda.empty_cache()
544
+ all_losses = torch.cat(all_losses, dim=-1)
545
+ if num_cand is not None:
546
+ for i in range(all_losses.shape[0]):
547
+ all_losses[i, num_cand[i]:] = 9999
548
+ output_class_ranks = torch.argsort(all_losses, dim=-1)
549
+ return output_class_ranks.tolist()
550
+
551
+ def predict_answers(
552
+ self,
553
+ samples,
554
+ num_beams=5,
555
+ inference_method="generate",
556
+ max_len=10,
557
+ min_len=1,
558
+ num_ans_candidates=128,
559
+ answer_list=None,
560
+ prompt="",
561
+ length_penalty=0,
562
+ **kwargs
563
+ ):
564
+ '''
565
+ function for open-ended VQA
566
+ '''
567
+ images = samples["image"].cuda()
568
+ texts = samples["instruction_input"]
569
+
570
+ output_text = self.generate(
571
+ images=images,
572
+ texts=texts,
573
+ num_beams=num_beams,
574
+ max_new_tokens=max_len,
575
+ min_length=min_len,
576
+ length_penalty=length_penalty
577
+ )
578
+
579
+ if "apply_lemmatizer" in samples.keys() and samples["apply_lemmatizer"]:
580
+ output_text = self._lemmatize(output_text)
581
+
582
+ return output_text
583
+
584
+ def predict_class(
585
+ self,
586
+ samples,
587
+ num_beams=5,
588
+ inference_method="generate",
589
+ max_len=10,
590
+ min_len=1,
591
+ num_ans_candidates=5,
592
+ answer_list=None,
593
+ prompt="",
594
+ length_penalty=0,
595
+ **kwargs
596
+ ):
597
+ '''
598
+ function for multi-choice VQA
599
+ '''
600
+
601
+ image = samples["image"].cuda()
602
+ instruction = samples['instruction_input']
603
+ answers = samples["choices"]
604
+ num_cand = samples["num_choices"]
605
+
606
+ ranks = self.multi_select(image, instruction, answers, num_cand)
607
+
608
+ pred_ans = []
609
+ for i, rank in enumerate(ranks):
610
+ pred = answers[rank[0]][i]
611
+ pred_ans.append(pred)
612
+ return pred_ans
613
+
614
+ def embed_tokens(self, token_ids):
615
+ try:
616
+ embeds = self.llama_model.base_model.model.model.embed_tokens(token_ids)
617
+ except AttributeError:
618
+ embeds = self.llama_model.model.embed_tokens(token_ids)
619
+
620
+ return embeds
621
+
622
+ @classmethod
623
+ def from_config(cls, cfg):
624
+ vit_model = cfg.get("vit_model", "eva_clip_g")
625
+ q_former_model = cfg.get("q_former_model", "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/blip2_pretrained_flant5xxl.pth")
626
+ img_size = cfg.get("image_size")
627
+ num_query_token = cfg.get("num_query_token")
628
+ llama_model = cfg.get("llama_model")
629
+
630
+ drop_path_rate = cfg.get("drop_path_rate", 0)
631
+ use_grad_checkpoint = cfg.get("use_grad_checkpoint", False)
632
+ vit_precision = cfg.get("vit_precision", "fp16")
633
+ freeze_vit = cfg.get("freeze_vit", True)
634
+ freeze_qformer = cfg.get("freeze_qformer", True)
635
+ low_resource = cfg.get("low_resource", False)
636
+
637
+ prompt_path = cfg.get("prompt_path", "")
638
+ prompt_template = cfg.get("prompt_template", "")
639
+ max_txt_len = cfg.get("max_txt_len", 300)
640
+ end_sym = cfg.get("end_sym", '\n')
641
+
642
+ lora_r = cfg.get("lora_r",64)
643
+ lora_alpha = cfg.get("lora_alpha",16)
644
+ chat_template = cfg.get("chat_template",False)
645
+ system_prompt = cfg.get("system_prompt", False)
646
+ token_pooling = cfg.get("token_pooling",True)
647
+
648
+ use_grad_checkpoint_llm = cfg.get("use_grad_checkpoint_llm", False)
649
+ max_context_len = cfg.get("max_context_len", 3800)
650
+ remove_template = cfg.get("remove_template", False)
651
+
652
+
653
+ model = cls(
654
+ vit_model=vit_model,
655
+ img_size=img_size,
656
+ drop_path_rate=drop_path_rate,
657
+ use_grad_checkpoint=use_grad_checkpoint,
658
+ vit_precision=vit_precision,
659
+ freeze_vit=freeze_vit,
660
+ llama_model=llama_model,
661
+ prompt_path=prompt_path,
662
+ prompt_template=prompt_template,
663
+ max_txt_len=max_txt_len,
664
+ low_resource=low_resource,
665
+ end_sym=end_sym,
666
+ lora_r = lora_r,
667
+ lora_alpha = lora_alpha,
668
+ chat_template = chat_template,
669
+ system_prompt = system_prompt,
670
+ token_pooling = token_pooling,
671
+ use_grad_checkpoint_llm=use_grad_checkpoint_llm,
672
+ max_context_len=max_context_len,
673
+ remove_template = remove_template
674
+ )
675
+
676
+ ckpt_path = cfg.get("ckpt", "") # load weights of MiniGPT-4
677
+ if ckpt_path:
678
+ print("Load Minigpt-4-LLM Checkpoint: {}".format(ckpt_path))
679
+ ckpt = torch.load(ckpt_path, map_location="cpu")
680
+ msg = model.load_state_dict(ckpt['model'], strict=False)
681
+
682
+ return model
683
+
684
+
685
+ def assign_imgs(batched_instruct_list, batched_img_embeds):
686
+ '''This function is used when the data is interleaved.
687
+ The interleaved data is split into segments, and this function assigns the
688
+ corresponding image embeddings to each segment.'''
689
+ if len(batched_img_embeds.shape) == 3:
690
+ batched_img_embeds = batched_img_embeds[:, None]
691
+
692
+ batched_assigned = []
693
+
694
+ for instruct_list, img_embeds in zip(batched_instruct_list, batched_img_embeds):
695
+ img_idx = 0
696
+ assigned_img = []
697
+ n_assigned = []
698
+ for instruct in instruct_list:
699
+ n_img = instruct.count('<ImageHere>')
700
+ if n_img > 0: # this instruction includes images
701
+ assigned_img.append(img_embeds[None, img_idx:img_idx+n_img])
702
+ img_idx += n_img
703
+ n_assigned.append(n_img)
704
+ else: # this instruction doesn't include images
705
+ assigned_img.append(None)
706
+ n_assigned.append(None)
707
+ batched_assigned.append(assigned_img)
708
+
709
+ return batched_assigned
mistral.py ADDED
@@ -0,0 +1,25 @@
1
+ from transformers import AutoModelForCausalLM, AutoTokenizer
2
+
3
+ device = "cuda" # the device to load the model onto
4
+
5
+ model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")
6
+ tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")
7
+
8
+ messages = [
9
+ {"role": "user", "content": "What is your favourite condiment?"},
10
+ {"role": "assistant", "content": "Well, I'm quite partial to a good squeeze of fresh lemon juice. It adds just the right amount of zesty flavour to whatever I'm cooking up in the kitchen!"},
11
+ {"role": "user", "content": "Do you have mayonnaise recipes?"}
12
+ ]
13
+ p="Well, I'm quite partial to a good squeeze of fresh lemon juice."
14
+ encoded_input = tokenizer(p, return_tensors='pt')
15
+ embeds = model.model.embed_tokens(encoded_input.input_ids)
16
+ print(embeds.shape)
17
+
18
+
19
+ encodeds = tokenizer.apply_chat_template(messages, return_tensors="pt")
20
+ model_inputs = encodeds.to(device)
21
+ model.to(device)
22
+
23
+ generated_ids = model.generate(model_inputs, max_new_tokens=1000, do_sample=True)
24
+ decoded = tokenizer.batch_decode(generated_ids)
25
+ print(decoded[0])
modeling_llama_v2.py ADDED
@@ -0,0 +1,137 @@
1
+ import math
2
+ from typing import List, Optional, Tuple, Union
3
+
4
+ import torch
5
+ import torch.nn.functional as F
6
+ from torch.nn import CrossEntropyLoss
7
+
8
+ from transformers.utils import add_start_docstrings_to_model_forward, replace_return_docstrings
9
+ from transformers.modeling_outputs import CausalLMOutputWithPast
10
+ from transformers.models.llama.modeling_llama import LLAMA_INPUTS_DOCSTRING, _CONFIG_FOR_DOC
11
+ from transformers.models.llama.modeling_llama import LlamaForCausalLM as LlamaForCausalLMOrig
12
+ # from minigpt4_video.models.transformers.src.transformers.models.llama.modeling_llama import LlamaForCausalLM as LlamaForCausalLMOrig
13
+
14
+ class LlamaForCausalLM(LlamaForCausalLMOrig):
15
+
16
+ @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
17
+ @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
18
+ def forward(
19
+ self,
20
+ input_ids: torch.LongTensor = None,
21
+ attention_mask: Optional[torch.Tensor] = None,
22
+ position_ids: Optional[torch.LongTensor] = None,
23
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
24
+ inputs_embeds: Optional[torch.FloatTensor] = None,
25
+ labels: Optional[torch.LongTensor] = None,
26
+ use_cache: Optional[bool] = None,
27
+ output_attentions: Optional[bool] = None,
28
+ output_hidden_states: Optional[bool] = None,
29
+ return_dict: Optional[bool] = None,
30
+ cache_position: Optional[torch.LongTensor] = None,
31
+ reduction: Optional[str] = "mean",
32
+ use_fastv: Optional[bool] = False,
33
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
34
+ r"""
35
+ Args:
36
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
37
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
38
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
39
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
40
+
41
+ Returns:
42
+
43
+ Example:
44
+
45
+ ```python
46
+ >>> from transformers import AutoTokenizer, LlamaForCausalLM
47
+
48
+ >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
49
+ >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
50
+
51
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
52
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
53
+
54
+ >>> # Generate
55
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
56
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
57
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
58
+ ```"""
59
+
60
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
61
+ output_hidden_states = (
62
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
63
+ )
64
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
65
+
66
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
67
+ if use_fastv:
68
+ fastv_config = {
69
+ "use_fastv": True,
70
+ "fastv_k": 3,
71
+ "fastv_r": 0.75,
72
+ "image_token_start_index": 5,
73
+ "image_token_length": 576
74
+ }
75
+ print(f"Using fastv: {fastv_config}")
76
+ outputs = self.model.fastv_forward(
77
+ input_ids=input_ids,
78
+ attention_mask=attention_mask,
79
+ position_ids=position_ids,
80
+ past_key_values=past_key_values,
81
+ inputs_embeds=inputs_embeds,
82
+ use_cache=use_cache,
83
+ output_attentions=output_attentions,
84
+ output_hidden_states=output_hidden_states,
85
+ return_dict=return_dict,
86
+ fastv_config=fastv_config,
87
+ cache_position=cache_position,
88
+ )
89
+ else:
90
+ outputs = self.model(
91
+ input_ids=input_ids,
92
+ attention_mask=attention_mask,
93
+ position_ids=position_ids,
94
+ past_key_values=past_key_values,
95
+ inputs_embeds=inputs_embeds,
96
+ use_cache=use_cache,
97
+ output_attentions=output_attentions,
98
+ output_hidden_states=output_hidden_states,
99
+ return_dict=return_dict,
100
+ # cache_position=cache_position,
101
+ )
102
+
103
+ hidden_states = outputs[0]
104
+ if self.config.pretraining_tp > 1:
105
+ lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)
106
+ logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
107
+ logits = torch.cat(logits, dim=-1)
108
+ else:
109
+ logits = self.lm_head(hidden_states)
110
+ logits = logits.float()
111
+
112
+ loss = None
113
+ if labels is not None:
114
+ # Shift so that tokens < n predict n
115
+ shift_logits = logits[..., :-1, :].contiguous()
116
+ shift_labels = labels[..., 1:].contiguous()
117
+ # Flatten the tokens
118
+ loss_fct = CrossEntropyLoss(reduction=reduction)
119
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
120
+ shift_labels = shift_labels.view(-1)
121
+ # Enable model parallelism
122
+ shift_labels = shift_labels.to(shift_logits.device)
123
+ loss = loss_fct(shift_logits, shift_labels)
124
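+ # with reduction='none', reshape per-token losses to (batch, seq_len - 1) and average over tokens to get a per-sample loss (used to rank answer candidates)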
+ if reduction == "none":
125
+ loss = loss.view(logits.size(0), -1).mean(1)
126
+
127
+ if not return_dict:
128
+ output = (logits,) + outputs[1:]
129
+ return (loss,) + output if loss is not None else output
130
+
131
+ return CausalLMOutputWithPast(
132
+ loss=loss,
133
+ logits=logits,
134
+ past_key_values=outputs.past_key_values,
135
+ hidden_states=outputs.hidden_states,
136
+ attentions=outputs.attentions,
137
+ )
modeling_mistral.py ADDED
@@ -0,0 +1,1388 @@
1
+ # coding=utf-8
2
+ # Copyright 2023 Mistral AI and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
5
+ # and OPT implementations in this library. It has been modified from its
6
+ # original forms to accommodate minor architectural differences compared
7
+ # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
8
+ #
9
+ # Licensed under the Apache License, Version 2.0 (the "License");
10
+ # you may not use this file except in compliance with the License.
11
+ # You may obtain a copy of the License at
12
+ #
13
+ # http://www.apache.org/licenses/LICENSE-2.0
14
+ #
15
+ # Unless required by applicable law or agreed to in writing, software
16
+ # distributed under the License is distributed on an "AS IS" BASIS,
17
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
+ # See the License for the specific language governing permissions and
19
+ # limitations under the License.
20
+ """ PyTorch Mistral model."""
21
+ import inspect
22
+ import math
23
+ import warnings
24
+ from typing import List, Optional, Tuple, Union
25
+
26
+ import torch
27
+ import torch.nn.functional as F
28
+ import torch.utils.checkpoint
29
+ from torch import nn
30
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
31
+
32
+ from transformers.activations import ACT2FN
33
+ from transformers.cache_utils import Cache, DynamicCache
34
+ from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa
35
+ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
36
+ from transformers.modeling_utils import PreTrainedModel
37
+ from transformers.utils import (
38
+ add_start_docstrings,
39
+ add_start_docstrings_to_model_forward,
40
+ is_flash_attn_2_available,
41
+ is_flash_attn_greater_or_equal_2_10,
42
+ logging,
43
+ replace_return_docstrings,
44
+ )
45
+ from transformers.models.mistral.configuration_mistral import MistralConfig
46
+
47
+
48
+ if is_flash_attn_2_available():
49
+ from flash_attn import flash_attn_func, flash_attn_varlen_func
50
+ from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
51
+
52
+ _flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters)
53
+
54
+
55
+ logger = logging.get_logger(__name__)
56
+
57
+ _CONFIG_FOR_DOC = "MistralConfig"
58
+
59
+
60
+ # Copied from transformers.models.llama.modeling_llama._get_unpad_data
61
+ def _get_unpad_data(attention_mask):
62
+ seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
63
+ indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
64
+ max_seqlen_in_batch = seqlens_in_batch.max().item()
65
+ cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
66
+ return (
67
+ indices,
68
+ cu_seqlens,
69
+ max_seqlen_in_batch,
70
+ )
71
+
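
A quick toy check of what `_get_unpad_data` returns (values worked out by hand for an assumed 2x3 mask); these are the indices and cumulative sequence lengths later fed to the varlen flash-attention kernels.

# Toy check of _get_unpad_data (shapes and values chosen for illustration only).
import torch

mask = torch.tensor([[1, 1, 0],
                     [1, 1, 1]])          # batch of 2, first row right-padded
indices, cu_seqlens, max_len = _get_unpad_data(mask)
# indices    -> tensor([0, 1, 3, 4, 5])   flattened positions of real tokens
# cu_seqlens -> tensor([0, 2, 5])         cumulative lengths used by flash-attn varlen
# max_len    -> 3                         longest unpadded sequence in the batch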
72
+
73
+ # Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Mistral
74
+ class MistralRMSNorm(nn.Module):
75
+ def __init__(self, hidden_size, eps=1e-6):
76
+ """
77
+ MistralRMSNorm is equivalent to T5LayerNorm
78
+ """
79
+ super().__init__()
80
+ self.weight = nn.Parameter(torch.ones(hidden_size))
81
+ self.variance_epsilon = eps
82
+
83
+ def forward(self, hidden_states):
84
+ input_dtype = hidden_states.dtype
85
+ hidden_states = hidden_states.to(torch.float32)
86
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
87
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
88
+ return self.weight * hidden_states.to(input_dtype)
89
+
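
In other words, the forward above implements x / sqrt(mean(x^2) + eps) scaled by a learned weight, with no mean subtraction. A minimal sanity sketch (toy shapes, weight still at its all-ones initialization):

# Sketch: MistralRMSNorm computes x * rsqrt(mean(x^2) + eps) * weight.
import torch

x = torch.randn(1, 4, 8)                       # (batch, seq, hidden) toy tensor
norm = MistralRMSNorm(hidden_size=8, eps=1e-6)
manual = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + 1e-6)
assert torch.allclose(norm(x), manual, atol=1e-6)  # weight is initialized to ones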
90
+
91
+ # copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Mistral
92
+ # TODO @Arthur no longer copied from Llama after static cache
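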
93
+ class MistralRotaryEmbedding(nn.Module):
94
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
95
+ super().__init__()
96
+
97
+ self.dim = dim
98
+ self.max_position_embeddings = max_position_embeddings
99
+ self.base = base
100
+ inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
101
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
102
+
103
+ # Build here to make `torch.jit.trace` work.
104
+ self._set_cos_sin_cache(
105
+ seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
106
+ )
107
+
108
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
109
+ self.max_seq_len_cached = seq_len
110
+ t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)
111
+
112
+ freqs = torch.outer(t, self.inv_freq)
113
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
114
+ emb = torch.cat((freqs, freqs), dim=-1)
115
+ self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
116
+ self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
117
+
118
+ def forward(self, x, seq_len=None):
119
+ # x: [bs, num_attention_heads, seq_len, head_size]
120
+ if seq_len > self.max_seq_len_cached:
121
+ self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
122
+
123
+ return (
124
+ self.cos_cached[:seq_len].to(dtype=x.dtype),
125
+ self.sin_cached[:seq_len].to(dtype=x.dtype),
126
+ )
127
+
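
The cache built above stores one cos/sin row per position, with each of the dim/2 frequencies duplicated across both halves of the head dimension so that `rotate_half` can pair them. A small shape sketch (toy dimensions, chosen only for illustration):

# Sketch: the rotary cache returns per-position cos/sin tables of width head_dim.
import torch

rope = MistralRotaryEmbedding(dim=8, max_position_embeddings=32)
dummy = torch.zeros(1, 1, 16, 8)            # (bs, heads, seq_len, head_dim)
cos, sin = rope(dummy, seq_len=16)
# cos.shape == sin.shape == (16, 8): one row per position, frequencies duplicated
# across both halves of head_dim so rotate_half can pair them up.
assert cos.shape == (16, 8)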
128
+
129
+ # Copied from transformers.models.llama.modeling_llama.rotate_half
130
+ def rotate_half(x):
131
+ """Rotates half the hidden dims of the input."""
132
+ x1 = x[..., : x.shape[-1] // 2]
133
+ x2 = x[..., x.shape[-1] // 2 :]
134
+ return torch.cat((-x2, x1), dim=-1)
135
+
136
+
137
+ # copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
138
+ # TODO @Arthur no longer copied from Llama after static cache
139
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
140
+ """Applies Rotary Position Embedding to the query and key tensors.
141
+
142
+ Args:
143
+ q (`torch.Tensor`): The query tensor.
144
+ k (`torch.Tensor`): The key tensor.
145
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
146
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
147
+ position_ids (`torch.Tensor`):
148
+ The position indices of the tokens corresponding to the query and key tensors. For example, this can be
149
+ used to pass offsetted position ids when working with a KV-cache.
150
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
151
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
152
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
153
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
154
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
155
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
156
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
157
+ Returns:
158
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
159
+ """
160
+ cos = cos[position_ids].unsqueeze(unsqueeze_dim)
161
+ sin = sin[position_ids].unsqueeze(unsqueeze_dim)
162
+ q_embed = (q * cos) + (rotate_half(q) * sin)
163
+ k_embed = (k * cos) + (rotate_half(k) * sin)
164
+ return q_embed, k_embed
165
+
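
A toy call showing how the pieces fit together: the cos/sin tables come from the rotary module above, `position_ids` selects the rows, and the rotation leaves all shapes unchanged (sizes below are arbitrary):

# Sketch: applying RoPE to toy q/k with the cos/sin tables from the module above.
import torch

bs, heads, kv_heads, seq, hd = 1, 4, 2, 6, 8
q = torch.randn(bs, heads, seq, hd)
k = torch.randn(bs, kv_heads, seq, hd)
rope = MistralRotaryEmbedding(dim=hd, max_position_embeddings=32)
cos, sin = rope(k, seq_len=seq)
position_ids = torch.arange(seq).unsqueeze(0)        # (1, seq)
q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin, position_ids)
# Shapes are preserved; only the last dimension is rotated pairwise.
assert q_rot.shape == q.shape and k_rot.shape == k.shape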
166
+
167
+ class MistralMLP(nn.Module):
168
+ def __init__(self, config):
169
+ super().__init__()
170
+ self.config = config
171
+ self.hidden_size = config.hidden_size
172
+ self.intermediate_size = config.intermediate_size
173
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
174
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
175
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
176
+ self.act_fn = ACT2FN[config.hidden_act]
177
+
178
+ def forward(self, x):
179
+ return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
180
+
181
+
182
+ # Copied from transformers.models.llama.modeling_llama.repeat_kv
183
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
184
+ """
185
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
186
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
187
+ """
188
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
189
+ if n_rep == 1:
190
+ return hidden_states
191
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
192
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
193
+
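
The docstring's claim can be checked directly on a small tensor: `repeat_kv` produces the same result as `torch.repeat_interleave` along the head dimension, implemented with an `expand` + `reshape` instead.

# Sketch: repeat_kv matches torch.repeat_interleave along the head dimension.
import torch

kv = torch.randn(2, 8, 5, 64)                 # (batch, num_kv_heads, seq, head_dim)
expanded = repeat_kv(kv, n_rep=4)             # -> (2, 32, 5, 64)
reference = torch.repeat_interleave(kv, repeats=4, dim=1)
assert torch.equal(expanded, reference)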
194
+
195
+ class MistralAttention(nn.Module):
196
+ """
197
+ Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer
198
+ and "Generating Long Sequences with Sparse Transformers".
199
+ """
200
+
201
+ def __init__(self, config: MistralConfig, layer_idx: Optional[int] = None):
202
+ super().__init__()
203
+ self.config = config
204
+ self.layer_idx = layer_idx
205
+ if layer_idx is None:
206
+ logger.warning_once(
207
+ f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
208
+ "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
209
+ "when creating this class."
210
+ )
211
+
212
+ self.hidden_size = config.hidden_size
213
+ self.num_heads = config.num_attention_heads
214
+ self.head_dim = self.hidden_size // self.num_heads
215
+ self.num_key_value_heads = config.num_key_value_heads
216
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
217
+ self.max_position_embeddings = config.max_position_embeddings
218
+ self.rope_theta = config.rope_theta
219
+ self.is_causal = True
220
+ self.attention_dropout = config.attention_dropout
221
+
222
+ if (self.head_dim * self.num_heads) != self.hidden_size:
223
+ raise ValueError(
224
+ f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
225
+ f" and `num_heads`: {self.num_heads})."
226
+ )
227
+ self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
228
+ self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
229
+ self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
230
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
231
+
232
+ self.rotary_emb = MistralRotaryEmbedding(
233
+ self.head_dim,
234
+ max_position_embeddings=self.max_position_embeddings,
235
+ base=self.rope_theta,
236
+ )
237
+
238
+ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
239
+ return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
240
+
241
+ def forward(
242
+ self,
243
+ hidden_states: torch.Tensor,
244
+ attention_mask: Optional[torch.Tensor] = None,
245
+ position_ids: Optional[torch.LongTensor] = None,
246
+ past_key_value: Optional[Cache] = None,
247
+ output_attentions: bool = False,
248
+ use_cache: bool = False,
249
+ **kwargs,
250
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
251
+ if "padding_mask" in kwargs:
252
+ warnings.warn(
253
+ "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
254
+ )
255
+ bsz, q_len, _ = hidden_states.size()
256
+
257
+ query_states = self.q_proj(hidden_states)
258
+ key_states = self.k_proj(hidden_states)
259
+ value_states = self.v_proj(hidden_states)
260
+
261
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
262
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
263
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
264
+
265
+ kv_seq_len = key_states.shape[-2]
266
+ if past_key_value is not None:
267
+ if self.layer_idx is None:
268
+ raise ValueError(
269
+ f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
270
+ "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
271
+ "with a layer index."
272
+ )
273
+ kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
274
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
275
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
276
+
277
+ if past_key_value is not None:
278
+ cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
279
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
280
+
281
+ # repeat k/v heads if n_kv_heads < n_heads
282
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
283
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
284
+
285
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
286
+
287
+ if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
288
+ raise ValueError(
289
+ f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
290
+ f" {attn_weights.size()}"
291
+ )
292
+
293
+ if attention_mask is not None:
294
+ if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
295
+ raise ValueError(
296
+ f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
297
+ )
298
+
299
+ attn_weights = attn_weights + attention_mask
300
+
301
+ # upcast attention to fp32
302
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
303
+ attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
304
+ attn_output = torch.matmul(attn_weights, value_states)
305
+
306
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
307
+ raise ValueError(
308
+ f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
309
+ f" {attn_output.size()}"
310
+ )
311
+
312
+ attn_output = attn_output.transpose(1, 2).contiguous()
313
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
314
+
315
+ attn_output = self.o_proj(attn_output)
316
+
317
+ if not output_attentions:
318
+ attn_weights = None
319
+
320
+ return attn_output, attn_weights, past_key_value
321
+
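
For orientation, the grouped-query head bookkeeping above works out as follows for Mistral-7B-sized numbers (the concrete values are taken as assumptions here, not read from any checkpoint):

# Sketch: head bookkeeping for grouped-query attention with Mistral-7B-like numbers
# (hidden_size=4096, 32 query heads, 8 key/value heads -- treated as assumptions).
hidden_size = 4096
num_heads = 32
num_key_value_heads = 8

head_dim = hidden_size // num_heads                       # 128
num_key_value_groups = num_heads // num_key_value_heads   # 4 query heads share each KV head
q_proj_out = num_heads * head_dim                         # 4096
kv_proj_out = num_key_value_heads * head_dim              # 1024: K/V projections are 4x smaller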
322
+
323
+ class MistralFlashAttention2(MistralAttention):
324
+ """
325
+ Mistral flash attention module. This module inherits from `MistralAttention` as the weights of the module stay
326
+ untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
327
+ flash attention and deal with padding tokens in case the input contains any of them.
328
+ """
329
+
330
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
331
+ def __init__(self, *args, **kwargs):
332
+ super().__init__(*args, **kwargs)
333
+
334
+ # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
335
+ # flash_attn<2.1 generates a top-left aligned causal mask, while what is needed here is bottom-right alignment, which was made the default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
336
+ # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
337
+ self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
338
+
339
+ def forward(
340
+ self,
341
+ hidden_states: torch.Tensor,
342
+ attention_mask: Optional[torch.Tensor] = None,
343
+ position_ids: Optional[torch.LongTensor] = None,
344
+ past_key_value: Optional[Cache] = None,
345
+ output_attentions: bool = False,
346
+ use_cache: bool = False,
347
+ **kwargs,
348
+ ):
349
+ if "padding_mask" in kwargs:
350
+ warnings.warn(
351
+ "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
352
+ )
353
+
354
+ # overwrite attention_mask with padding_mask
355
+ attention_mask = kwargs.pop("padding_mask")
356
+ bsz, q_len, _ = hidden_states.size()
357
+
358
+ query_states = self.q_proj(hidden_states)
359
+ key_states = self.k_proj(hidden_states)
360
+ value_states = self.v_proj(hidden_states)
361
+
362
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
363
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
364
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
365
+
366
+ kv_seq_len = key_states.shape[-2]
367
+ if past_key_value is not None:
368
+ if self.layer_idx is None:
369
+ raise ValueError(
370
+ f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
371
+ "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
372
+ "with a layer index."
373
+ )
374
+ kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
375
+
376
+ # Because the input can be padded, the absolute sequence length depends on the max position id.
377
+ rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item()) + 1
378
+ cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len)
379
+
380
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
381
+
382
+ use_sliding_windows = (
383
+ _flash_supports_window_size
384
+ and getattr(self.config, "sliding_window", None) is not None
385
+ and kv_seq_len > self.config.sliding_window
386
+ )
387
+
388
+ if not _flash_supports_window_size:
389
+ logger.warning_once(
390
+ "The current flash attention version does not support sliding window attention, for a more memory efficient implementation"
391
+ " make sure to upgrade flash-attn library."
392
+ )
393
+
394
+ if past_key_value is not None:
395
+ # Activate cache slicing only if the config has a `sliding_window` attribute
396
+ cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0
397
+ if (
398
+ getattr(self.config, "sliding_window", None) is not None
399
+ and kv_seq_len > self.config.sliding_window
400
+ and cache_has_contents
401
+ ):
402
+ slicing_tokens = 1 - self.config.sliding_window
403
+
404
+ past_key = past_key_value[self.layer_idx][0]
405
+ past_value = past_key_value[self.layer_idx][1]
406
+
407
+ past_key = past_key[:, :, slicing_tokens:, :].contiguous()
408
+ past_value = past_value[:, :, slicing_tokens:, :].contiguous()
409
+
410
+ if past_key.shape[-2] != self.config.sliding_window - 1:
411
+ raise ValueError(
412
+ f"past key must have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got"
413
+ f" {past_key.shape}"
414
+ )
415
+
416
+ if attention_mask is not None:
417
+ attention_mask = attention_mask[:, slicing_tokens:]
418
+ attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1)
419
+
420
+ cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
421
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
422
+
423
+ # repeat k/v heads if n_kv_heads < n_heads
424
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
425
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
426
+ dropout_rate = 0.0 if not self.training else self.attention_dropout
427
+
428
+ # In PEFT, usually we cast the layer norms in float32 for training stability reasons
429
+ # therefore the input hidden states get silently cast to float32. Hence, we need to
430
+ # cast them back to float16 just to be sure everything works as expected.
431
+ input_dtype = query_states.dtype
432
+ if input_dtype == torch.float32:
433
+ if torch.is_autocast_enabled():
434
+ target_dtype = torch.get_autocast_gpu_dtype()
435
+ # Handle the case where the model is quantized
436
+ elif hasattr(self.config, "_pre_quantization_dtype"):
437
+ target_dtype = self.config._pre_quantization_dtype
438
+ else:
439
+ target_dtype = self.q_proj.weight.dtype
440
+
441
+ logger.warning_once(
442
+ f"The input hidden states seems to be silently casted in float32, this might be related to"
443
+ f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
444
+ f" {target_dtype}."
445
+ )
446
+
447
+ query_states = query_states.to(target_dtype)
448
+ key_states = key_states.to(target_dtype)
449
+ value_states = value_states.to(target_dtype)
450
+
451
+ # Reshape to the expected shape for Flash Attention
452
+ query_states = query_states.transpose(1, 2)
453
+ key_states = key_states.transpose(1, 2)
454
+ value_states = value_states.transpose(1, 2)
455
+
456
+ attn_output = self._flash_attention_forward(
457
+ query_states,
458
+ key_states,
459
+ value_states,
460
+ attention_mask,
461
+ q_len,
462
+ dropout=dropout_rate,
463
+ use_sliding_windows=use_sliding_windows,
464
+ )
465
+
466
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
467
+ attn_output = self.o_proj(attn_output)
468
+
469
+ if not output_attentions:
470
+ attn_weights = None
471
+
472
+ return attn_output, attn_weights, past_key_value
473
+
474
+ def _flash_attention_forward(
475
+ self,
476
+ query_states,
477
+ key_states,
478
+ value_states,
479
+ attention_mask,
480
+ query_length,
481
+ dropout=0.0,
482
+ softmax_scale=None,
483
+ use_sliding_windows=False,
484
+ ):
485
+ """
486
+ Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token,
487
+ it first unpads the input, then computes the attention scores, and finally re-pads the attention output.
488
+
489
+ Args:
490
+ query_states (`torch.Tensor`):
491
+ Input query states to be passed to Flash Attention API
492
+ key_states (`torch.Tensor`):
493
+ Input key states to be passed to Flash Attention API
494
+ value_states (`torch.Tensor`):
495
+ Input value states to be passed to Flash Attention API
496
+ attention_mask (`torch.Tensor`):
497
+ The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
498
+ position of padding tokens and 1 for the position of non-padding tokens.
499
+ dropout (`float`, *optional*):
500
+ Attention dropout
501
+ softmax_scale (`float`, *optional*):
502
+ The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
503
+ use_sliding_windows (`bool`, *optional*):
504
+ Whether to activate sliding window attention.
505
+ """
506
+ if not self._flash_attn_uses_top_left_mask:
507
+ causal = self.is_causal
508
+ else:
509
+ # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
510
+ causal = self.is_causal and query_length != 1
511
+
512
+ # Contains at least one padding token in the sequence
513
+ if attention_mask is not None:
514
+ batch_size = query_states.shape[0]
515
+ query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
516
+ query_states, key_states, value_states, attention_mask, query_length
517
+ )
518
+
519
+ cu_seqlens_q, cu_seqlens_k = cu_seq_lens
520
+ max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
521
+
522
+ if not use_sliding_windows:
523
+ attn_output_unpad = flash_attn_varlen_func(
524
+ query_states,
525
+ key_states,
526
+ value_states,
527
+ cu_seqlens_q=cu_seqlens_q,
528
+ cu_seqlens_k=cu_seqlens_k,
529
+ max_seqlen_q=max_seqlen_in_batch_q,
530
+ max_seqlen_k=max_seqlen_in_batch_k,
531
+ dropout_p=dropout,
532
+ softmax_scale=softmax_scale,
533
+ causal=causal,
534
+ )
535
+ else:
536
+ attn_output_unpad = flash_attn_varlen_func(
537
+ query_states,
538
+ key_states,
539
+ value_states,
540
+ cu_seqlens_q=cu_seqlens_q,
541
+ cu_seqlens_k=cu_seqlens_k,
542
+ max_seqlen_q=max_seqlen_in_batch_q,
543
+ max_seqlen_k=max_seqlen_in_batch_k,
544
+ dropout_p=dropout,
545
+ softmax_scale=softmax_scale,
546
+ causal=causal,
547
+ window_size=(self.config.sliding_window, self.config.sliding_window),
548
+ )
549
+
550
+ attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
551
+ else:
552
+ if not use_sliding_windows:
553
+ attn_output = flash_attn_func(
554
+ query_states,
555
+ key_states,
556
+ value_states,
557
+ dropout,
558
+ softmax_scale=softmax_scale,
559
+ causal=causal,
560
+ )
561
+ else:
562
+ attn_output = flash_attn_func(
563
+ query_states,
564
+ key_states,
565
+ value_states,
566
+ dropout,
567
+ softmax_scale=softmax_scale,
568
+ causal=causal,
569
+ window_size=(self.config.sliding_window, self.config.sliding_window),
570
+ )
571
+
572
+ return attn_output
573
+
574
+ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
575
+ batch_size, kv_seq_len, num_heads, head_dim = key_layer.shape
576
+
577
+ # On the first iteration we need to properly re-create the padding mask
578
+ # by slicing it at the proper place
579
+ if kv_seq_len != attention_mask.shape[-1]:
580
+ attention_mask_num_tokens = attention_mask.shape[-1]
581
+ attention_mask = attention_mask[:, attention_mask_num_tokens - kv_seq_len :]
582
+
583
+ indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
584
+
585
+ key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k)
586
+ value_layer = index_first_axis(value_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k)
587
+
588
+ if query_length == kv_seq_len:
589
+ query_layer = index_first_axis(
590
+ query_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k
591
+ )
592
+ cu_seqlens_q = cu_seqlens_k
593
+ max_seqlen_in_batch_q = max_seqlen_in_batch_k
594
+ indices_q = indices_k
595
+ elif query_length == 1:
596
+ max_seqlen_in_batch_q = 1
597
+ cu_seqlens_q = torch.arange(
598
+ batch_size + 1, dtype=torch.int32, device=query_layer.device
599
+ ) # There is a memcpy here, that is very bad.
600
+ indices_q = cu_seqlens_q[:-1]
601
+ query_layer = query_layer.squeeze(1)
602
+ else:
603
+ # The -q_len: slice assumes left padding.
604
+ attention_mask = attention_mask[:, -query_length:]
605
+ query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
606
+
607
+ return (
608
+ query_layer,
609
+ key_layer,
610
+ value_layer,
611
+ indices_q,
612
+ (cu_seqlens_q, cu_seqlens_k),
613
+ (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
614
+ )
615
+
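
The cache-slicing branch above keeps only the most recent `sliding_window - 1` cached positions before the current token's K/V are appended, so the query never attends further back than the window. A small arithmetic sketch (the window size is an assumed default):

# Sketch: sliding-window cache slicing as done above (window size is an assumption).
sliding_window = 4096
slicing_tokens = 1 - sliding_window          # -4095: keep the last 4095 cached positions
# past_key[:, :, slicing_tokens:, :] therefore retains sliding_window - 1 entries,
# so after the current token's K/V are appended the cache holds exactly
# `sliding_window` positions visible to the new query.
cache_len_after_update = (sliding_window - 1) + 1   # 4096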
616
+
617
+ # copied from transformers.models.llama.modeling_llama.LlamaSdpaAttention with Llama->Mistral
618
+ # TODO @Arthur no longer copied from Llama after static cache
619
+ class MistralSdpaAttention(MistralAttention):
620
+ """
621
+ Mistral attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
622
+ `MistralAttention` as the weights of the module stay untouched. The only changes are on the forward pass to adapt to
623
+ SDPA API.
624
+ """
625
+
626
+ # Adapted from MistralAttention.forward
627
+ def forward(
628
+ self,
629
+ hidden_states: torch.Tensor,
630
+ attention_mask: Optional[torch.Tensor] = None,
631
+ position_ids: Optional[torch.LongTensor] = None,
632
+ past_key_value: Optional[Cache] = None,
633
+ output_attentions: bool = False,
634
+ use_cache: bool = False,
635
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
636
+ if output_attentions:
637
+ # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
638
+ logger.warning_once(
639
+ "MistralModel is using MistralSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
640
+ 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
641
+ )
642
+ return super().forward(
643
+ hidden_states=hidden_states,
644
+ attention_mask=attention_mask,
645
+ position_ids=position_ids,
646
+ past_key_value=past_key_value,
647
+ output_attentions=output_attentions,
648
+ use_cache=use_cache,
649
+ )
650
+
651
+ bsz, q_len, _ = hidden_states.size()
652
+
653
+ query_states = self.q_proj(hidden_states)
654
+ key_states = self.k_proj(hidden_states)
655
+ value_states = self.v_proj(hidden_states)
656
+
657
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
658
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
659
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
660
+
661
+ kv_seq_len = key_states.shape[-2]
662
+ if past_key_value is not None:
663
+ kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
664
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
665
+
666
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
667
+
668
+ if past_key_value is not None:
669
+ cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
670
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
671
+
672
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
673
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
674
+
675
+ if attention_mask is not None:
676
+ if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
677
+ raise ValueError(
678
+ f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
679
+ )
680
+
681
+ # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
682
+ # Reference: https://github.com/pytorch/pytorch/issues/112577.
683
+ if query_states.device.type == "cuda" and attention_mask is not None:
684
+ query_states = query_states.contiguous()
685
+ key_states = key_states.contiguous()
686
+ value_states = value_states.contiguous()
687
+
688
+ attn_output = torch.nn.functional.scaled_dot_product_attention(
689
+ query_states,
690
+ key_states,
691
+ value_states,
692
+ attn_mask=attention_mask,
693
+ dropout_p=self.attention_dropout if self.training else 0.0,
694
+ # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
695
+ is_causal=self.is_causal and attention_mask is None and q_len > 1,
696
+ )
697
+
698
+ attn_output = attn_output.transpose(1, 2).contiguous()
699
+ attn_output = attn_output.view(bsz, q_len, self.hidden_size)
700
+
701
+ attn_output = self.o_proj(attn_output)
702
+
703
+ return attn_output, None, past_key_value
704
+
705
+
706
+ MISTRAL_ATTENTION_CLASSES = {
707
+ "eager": MistralAttention,
708
+ "flash_attention_2": MistralFlashAttention2,
709
+ "sdpa": MistralSdpaAttention,
710
+ }
711
+
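
This mapping is what `MistralDecoderLayer.__init__` indexes with `config._attn_implementation`; a minimal sketch of the lookup (the chosen key is illustrative):

# Sketch: resolving an attention implementation by name (illustrative key).
attn_cls = MISTRAL_ATTENTION_CLASSES["sdpa"]   # -> MistralSdpaAttention
# Inside MistralDecoderLayer.__init__ the key comes from config._attn_implementation,
# which falls back to "eager" (the plain PyTorch attention above) when nothing is requested.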
712
+
713
+ class MistralDecoderLayer(nn.Module):
714
+ def __init__(self, config: MistralConfig, layer_idx: int):
715
+ super().__init__()
716
+ self.hidden_size = config.hidden_size
717
+
718
+ self.self_attn = MISTRAL_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)
719
+
720
+ self.mlp = MistralMLP(config)
721
+ self.input_layernorm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
722
+ self.post_attention_layernorm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
723
+
724
+ def forward(
725
+ self,
726
+ hidden_states: torch.Tensor,
727
+ attention_mask: Optional[torch.Tensor] = None,
728
+ position_ids: Optional[torch.LongTensor] = None,
729
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
730
+ output_attentions: Optional[bool] = False,
731
+ use_cache: Optional[bool] = False,
732
+ **kwargs,
733
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
734
+ if "padding_mask" in kwargs:
735
+ warnings.warn(
736
+ "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
737
+ )
738
+ """
739
+ Args:
740
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
741
+ attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
742
+ `(batch, sequence_length)` where padding elements are indicated by 0.
743
+ output_attentions (`bool`, *optional*):
744
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
745
+ returned tensors for more detail.
746
+ use_cache (`bool`, *optional*):
747
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
748
+ (see `past_key_values`).
749
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
750
+ """
751
+
752
+ residual = hidden_states
753
+
754
+ hidden_states = self.input_layernorm(hidden_states)
755
+
756
+ # Self Attention
757
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
758
+ hidden_states=hidden_states,
759
+ attention_mask=attention_mask,
760
+ position_ids=position_ids,
761
+ past_key_value=past_key_value,
762
+ output_attentions=output_attentions,
763
+ use_cache=use_cache,
764
+ )
765
+ hidden_states = residual + hidden_states
766
+
767
+ # Fully Connected
768
+ residual = hidden_states
769
+ hidden_states = self.post_attention_layernorm(hidden_states)
770
+ hidden_states = self.mlp(hidden_states)
771
+ hidden_states = residual + hidden_states
772
+
773
+ outputs = (hidden_states,)
774
+
775
+ if output_attentions:
776
+ outputs += (self_attn_weights,)
777
+
778
+ if use_cache:
779
+ outputs += (present_key_value,)
780
+
781
+ return outputs
782
+
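
The layer above is a standard pre-norm block: RMSNorm -> attention -> residual, then RMSNorm -> MLP -> residual. A minimal forward pass on a deliberately tiny config (all sizes are assumptions for illustration; with no mask passed, this toy call is not causal):

# Sketch: one decoder layer on toy sizes (pre-norm attention + MLP, residuals around both).
import torch

config = MistralConfig(hidden_size=64, intermediate_size=128, num_hidden_layers=2,
                       num_attention_heads=8, num_key_value_heads=4,
                       max_position_embeddings=128)
# With nothing configured, config._attn_implementation falls back to "eager",
# so this builds the plain MistralAttention defined above.
layer = MistralDecoderLayer(config, layer_idx=0)

hidden = torch.randn(1, 10, config.hidden_size)
position_ids = torch.arange(10).unsqueeze(0)
out = layer(hidden, position_ids=position_ids)[0]   # no attention_mask: full (non-causal) toy pass
assert out.shape == hidden.shape                    # (1, 10, 64)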
783
+
784
+ MISTRAL_START_DOCSTRING = r"""
785
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
786
+ library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
787
+ etc.)
788
+
789
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
790
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
791
+ and behavior.
792
+
793
+ Parameters:
794
+ config ([`MistralConfig`]):
795
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
796
+ load the weights associated with the model, only the configuration. Check out the
797
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
798
+ """
799
+
800
+
801
+ @add_start_docstrings(
802
+ "The bare Mistral Model outputting raw hidden-states without any specific head on top.",
803
+ MISTRAL_START_DOCSTRING,
804
+ )
805
+ class MistralPreTrainedModel(PreTrainedModel):
806
+ config_class = MistralConfig
807
+ base_model_prefix = "model"
808
+ supports_gradient_checkpointing = True
809
+ _no_split_modules = ["MistralDecoderLayer"]
810
+ _skip_keys_device_placement = "past_key_values"
811
+ _supports_flash_attn_2 = True
812
+ _supports_sdpa = True
813
+ _supports_cache_class = True
814
+
815
+ def _init_weights(self, module):
816
+ std = self.config.initializer_range
817
+ if isinstance(module, nn.Linear):
818
+ module.weight.data.normal_(mean=0.0, std=std)
819
+ if module.bias is not None:
820
+ module.bias.data.zero_()
821
+ elif isinstance(module, nn.Embedding):
822
+ module.weight.data.normal_(mean=0.0, std=std)
823
+ if module.padding_idx is not None:
824
+ module.weight.data[module.padding_idx].zero_()
825
+
826
+
827
+ MISTRAL_INPUTS_DOCSTRING = r"""
828
+ Args:
829
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
830
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
831
+ it.
832
+
833
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
834
+ [`PreTrainedTokenizer.__call__`] for details.
835
+
836
+ [What are input IDs?](../glossary#input-ids)
837
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
838
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
839
+
840
+ - 1 for tokens that are **not masked**,
841
+ - 0 for tokens that are **masked**.
842
+
843
+ [What are attention masks?](../glossary#attention-mask)
844
+
845
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
846
+ [`PreTrainedTokenizer.__call__`] for details.
847
+
848
+ If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
849
+ `past_key_values`).
850
+
851
+ If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
852
+ and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
853
+ information on the default strategy.
854
+
855
+ - 1 indicates the head is **not masked**,
856
+ - 0 indicates the head is **masked**.
857
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
858
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
859
+ config.n_positions - 1]`.
860
+
861
+ [What are position IDs?](../glossary#position-ids)
862
+ past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
863
+ Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
864
+ blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
865
+ returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
866
+
867
+ Two formats are allowed:
868
+ - a [`~cache_utils.Cache`] instance;
869
+ - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
870
+ shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
871
+ cache format.
872
+
873
+ The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
874
+ legacy cache format will be returned.
875
+
876
+ If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
877
+ have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
878
+ of shape `(batch_size, sequence_length)`.
879
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
880
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
881
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
882
+ model's internal embedding lookup matrix.
883
+ use_cache (`bool`, *optional*):
884
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
885
+ `past_key_values`).
886
+ output_attentions (`bool`, *optional*):
887
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
888
+ tensors for more detail.
889
+ output_hidden_states (`bool`, *optional*):
890
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
891
+ more detail.
892
+ return_dict (`bool`, *optional*):
893
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
894
+ """
895
+
896
+
897
+ @add_start_docstrings(
898
+ "The bare Mistral Model outputting raw hidden-states without any specific head on top.",
899
+ MISTRAL_START_DOCSTRING,
900
+ )
901
+ class MistralModel(MistralPreTrainedModel):
902
+ """
903
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`MistralDecoderLayer`]
904
+
905
+ Args:
906
+ config: MistralConfig
907
+ """
908
+
909
+ def __init__(self, config: MistralConfig):
910
+ super().__init__(config)
911
+ self.padding_idx = config.pad_token_id
912
+ self.vocab_size = config.vocab_size
913
+
914
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
915
+ self.layers = nn.ModuleList(
916
+ [MistralDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
917
+ )
918
+ self._attn_implementation = config._attn_implementation
919
+ self.norm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
920
+
921
+ self.gradient_checkpointing = False
922
+ # Initialize weights and apply final processing
923
+ self.post_init()
924
+
925
+ def get_input_embeddings(self):
926
+ return self.embed_tokens
927
+
928
+ def set_input_embeddings(self, value):
929
+ self.embed_tokens = value
930
+
931
+ @add_start_docstrings_to_model_forward(MISTRAL_INPUTS_DOCSTRING)
932
+ def forward(
933
+ self,
934
+ input_ids: torch.LongTensor = None,
935
+ attention_mask: Optional[torch.Tensor] = None,
936
+ position_ids: Optional[torch.LongTensor] = None,
937
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
938
+ inputs_embeds: Optional[torch.FloatTensor] = None,
939
+ use_cache: Optional[bool] = None,
940
+ output_attentions: Optional[bool] = None,
941
+ output_hidden_states: Optional[bool] = None,
942
+ return_dict: Optional[bool] = None,
943
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
944
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
945
+ output_hidden_states = (
946
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
947
+ )
948
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
949
+
950
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
951
+
952
+ # retrieve input_ids and inputs_embeds
953
+ if input_ids is not None and inputs_embeds is not None:
954
+ raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
955
+ elif input_ids is not None:
956
+ batch_size, seq_length = input_ids.shape
957
+ elif inputs_embeds is not None:
958
+ batch_size, seq_length, _ = inputs_embeds.shape
959
+ else:
960
+ raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
961
+
962
+ if self.gradient_checkpointing and self.training:
963
+ if use_cache:
964
+ logger.warning_once(
965
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
966
+ )
967
+ use_cache = False
968
+
969
+ past_key_values_length = 0
970
+
971
+ if use_cache:
972
+ use_legacy_cache = not isinstance(past_key_values, Cache)
973
+ if use_legacy_cache:
974
+ past_key_values = DynamicCache.from_legacy_cache(past_key_values)
975
+ past_key_values_length = past_key_values.get_usable_length(seq_length)
976
+
977
+ if position_ids is None:
978
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
979
+ position_ids = torch.arange(
980
+ past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
981
+ )
982
+ position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
983
+ else:
984
+ position_ids = position_ids.view(-1, seq_length).long()
985
+
986
+ if inputs_embeds is None:
987
+ inputs_embeds = self.embed_tokens(input_ids)
988
+
989
+ if attention_mask is not None and self._attn_implementation == "flash_attention_2" and use_cache:
990
+ is_padding_right = attention_mask[:, -1].sum().item() != batch_size
991
+ if is_padding_right:
992
+ raise ValueError(
993
+ "You are attempting to perform batched generation with padding_side='right'"
994
+ " this may lead to unexpected behaviour for Flash Attention version of Mistral. Make sure to "
995
+ " call `tokenizer.padding_side = 'left'` before tokenizing the input. "
996
+ )
997
+
998
+ if self._attn_implementation == "flash_attention_2":
999
+ # 2d mask is passed through the layers
1000
+ attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
1001
+ elif self._attn_implementation == "sdpa" and not output_attentions:
1002
+ # output_attentions=True can not be supported when using SDPA, and we fall back on
1003
+ # the manual implementation that requires a 4D causal mask in all cases.
1004
+ attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
1005
+ attention_mask,
1006
+ (batch_size, seq_length),
1007
+ inputs_embeds,
1008
+ past_key_values_length,
1009
+ )
1010
+ else:
1011
+ # 4d mask is passed through the layers
1012
+ attention_mask = _prepare_4d_causal_attention_mask(
1013
+ attention_mask,
1014
+ (batch_size, seq_length),
1015
+ inputs_embeds,
1016
+ past_key_values_length,
1017
+ sliding_window=self.config.sliding_window,
1018
+ )
1019
+
1020
+ hidden_states = inputs_embeds
1021
+
1022
+ # decoder layers
1023
+ all_hidden_states = () if output_hidden_states else None
1024
+ all_self_attns = () if output_attentions else None
1025
+ next_decoder_cache = None
1026
+
1027
+ for decoder_layer in self.layers:
1028
+ if output_hidden_states:
1029
+ all_hidden_states += (hidden_states,)
1030
+
1031
+ if self.gradient_checkpointing and self.training:
1032
+ layer_outputs = self._gradient_checkpointing_func(
1033
+ decoder_layer.__call__,
1034
+ hidden_states,
1035
+ attention_mask,
1036
+ position_ids,
1037
+ past_key_values,
1038
+ output_attentions,
1039
+ use_cache,
1040
+ )
1041
+ else:
1042
+ layer_outputs = decoder_layer(
1043
+ hidden_states,
1044
+ attention_mask=attention_mask,
1045
+ position_ids=position_ids,
1046
+ past_key_value=past_key_values,
1047
+ output_attentions=output_attentions,
1048
+ use_cache=use_cache,
1049
+ )
1050
+
1051
+ hidden_states = layer_outputs[0]
1052
+
1053
+ if use_cache:
1054
+ next_decoder_cache = layer_outputs[2 if output_attentions else 1]
1055
+
1056
+ if output_attentions:
1057
+ all_self_attns += (layer_outputs[1],)
1058
+
1059
+ hidden_states = self.norm(hidden_states)
1060
+
1061
+ # add hidden states from the last decoder layer
1062
+ if output_hidden_states:
1063
+ all_hidden_states += (hidden_states,)
1064
+
1065
+ next_cache = None
1066
+ if use_cache:
1067
+ next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
1068
+
1069
+ if not return_dict:
1070
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
1071
+ return BaseModelOutputWithPast(
1072
+ last_hidden_state=hidden_states,
1073
+ past_key_values=next_cache,
1074
+ hidden_states=all_hidden_states,
1075
+ attentions=all_self_attns,
1076
+ )
1077
+
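
A minimal end-to-end pass through the bare model with a tiny, made-up config (sizes are assumptions, not the released 7B checkpoint); it shows the shape of `last_hidden_state` and the per-layer KV cache returned when `use_cache=True`:

# Sketch: running the bare MistralModel on random token ids with a toy config.
import torch

config = MistralConfig(vocab_size=100, hidden_size=64, intermediate_size=128,
                       num_hidden_layers=2, num_attention_heads=8,
                       num_key_value_heads=4, max_position_embeddings=128)
model = MistralModel(config).eval()

input_ids = torch.randint(0, config.vocab_size, (1, 12))
with torch.no_grad():
    out = model(input_ids=input_ids, use_cache=True)
print(out.last_hidden_state.shape)        # torch.Size([1, 12, 64])
print(len(out.past_key_values))           # 2: one K/V pair per decoder layer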
1078
+
1079
+ class MistralForCausalLM(MistralPreTrainedModel):
1080
+ _tied_weights_keys = ["lm_head.weight"]
1081
+
1082
+ def __init__(self, config):
1083
+ super().__init__(config)
1084
+ self.model = MistralModel(config)
1085
+ self.vocab_size = config.vocab_size
1086
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
1087
+
1088
+ # Initialize weights and apply final processing
1089
+ self.post_init()
1090
+
1091
+ def get_input_embeddings(self):
1092
+ return self.model.embed_tokens
1093
+
1094
+ def set_input_embeddings(self, value):
1095
+ self.model.embed_tokens = value
1096
+
1097
+ def get_output_embeddings(self):
1098
+ return self.lm_head
1099
+
1100
+ def set_output_embeddings(self, new_embeddings):
1101
+ self.lm_head = new_embeddings
1102
+
1103
+ def set_decoder(self, decoder):
1104
+ self.model = decoder
1105
+
1106
+ def get_decoder(self):
1107
+ return self.model
1108
+
1109
+ @add_start_docstrings_to_model_forward(MISTRAL_INPUTS_DOCSTRING)
1110
+ @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
1111
+ def forward(
1112
+ self,
1113
+ input_ids: torch.LongTensor = None,
1114
+ attention_mask: Optional[torch.Tensor] = None,
1115
+ position_ids: Optional[torch.LongTensor] = None,
1116
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
1117
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1118
+ labels: Optional[torch.LongTensor] = None,
1119
+ use_cache: Optional[bool] = None,
1120
+ output_attentions: Optional[bool] = None,
1121
+ output_hidden_states: Optional[bool] = None,
1122
+ return_dict: Optional[bool] = None,
1123
+ reduction: Optional[str] = "mean",
1124
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
1125
+ r"""
1126
+ Args:
1127
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1128
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
1129
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
1130
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
1131
+
1132
+ Returns:
1133
+
1134
+ Example:
1135
+
1136
+ ```python
1137
+ >>> from transformers import AutoTokenizer, MistralForCausalLM
1138
+
1139
+ >>> model = MistralForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1")
1140
+ >>> tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
1141
+
1142
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
1143
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
1144
+
1145
+ >>> # Generate
1146
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
1147
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
1148
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
1149
+ ```"""
1150
+
1151
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1152
+ output_hidden_states = (
1153
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1154
+ )
1155
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1156
+
1157
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
1158
+ outputs = self.model(
1159
+ input_ids=input_ids,
1160
+ attention_mask=attention_mask,
1161
+ position_ids=position_ids,
1162
+ past_key_values=past_key_values,
1163
+ inputs_embeds=inputs_embeds,
1164
+ use_cache=use_cache,
1165
+ output_attentions=output_attentions,
1166
+ output_hidden_states=output_hidden_states,
1167
+ return_dict=return_dict,
1168
+ )
1169
+
1170
+ hidden_states = outputs[0]
1171
+ logits = self.lm_head(hidden_states)
1172
+ logits = logits.float()
1173
+
1174
+ loss = None
1175
+ if labels is not None:
1176
+ # Shift so that tokens < n predict n
1177
+ shift_logits = logits[..., :-1, :].contiguous()
1178
+ shift_labels = labels[..., 1:].contiguous()
1179
+ # Flatten the tokens
1180
+ loss_fct = CrossEntropyLoss(reduction=reduction)
1181
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
1182
+ shift_labels = shift_labels.view(-1)
1183
+ # Enable model parallelism
1184
+ shift_labels = shift_labels.to(shift_logits.device)
1185
+ loss = loss_fct(shift_logits, shift_labels)
1186
+ if reduction == "none":
1187
+ loss = loss.view(logits.size(0), -1).mean(1)
1188
+ if not return_dict:
1189
+ output = (logits,) + outputs[1:]
1190
+ return (loss,) + output if loss is not None else output
1191
+
1192
+ return CausalLMOutputWithPast(
1193
+ loss=loss,
1194
+ logits=logits,
1195
+ past_key_values=outputs.past_key_values,
1196
+ hidden_states=outputs.hidden_states,
1197
+ attentions=outputs.attentions,
1198
+ )
1199
+
1200
+ def prepare_inputs_for_generation(
1201
+ self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
1202
+ ):
1203
+ # Omit tokens covered by past_key_values
1204
+ if past_key_values is not None:
1205
+ if isinstance(past_key_values, Cache):
1206
+ cache_length = past_key_values.get_seq_length()
1207
+ past_length = past_key_values.seen_tokens
1208
+ max_cache_length = past_key_values.get_max_length()
1209
+ else:
1210
+ cache_length = past_length = past_key_values[0][0].shape[2]
1211
+ max_cache_length = None
1212
+
1213
+ # Keep only the unprocessed tokens:
1214
+ # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
1215
+ # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
1216
+ # input)
1217
+ if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
1218
+ input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
1219
+ # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
1220
+ # input_ids based on the past_length.
1221
+ elif past_length < input_ids.shape[1]:
1222
+ input_ids = input_ids[:, past_length:]
1223
+ # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
1224
+
1225
+ # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
1226
+ if (
1227
+ max_cache_length is not None
1228
+ and attention_mask is not None
1229
+ and cache_length + input_ids.shape[1] > max_cache_length
1230
+ ):
1231
+ attention_mask = attention_mask[:, -max_cache_length:]
1232
+
1233
+ position_ids = kwargs.get("position_ids", None)
1234
+ if attention_mask is not None and position_ids is None:
1235
+ # create position_ids on the fly for batch generation
1236
+ position_ids = attention_mask.long().cumsum(-1) - 1
1237
+ position_ids.masked_fill_(attention_mask == 0, 1)
1238
+ if past_key_values:
1239
+ position_ids = position_ids[:, -input_ids.shape[1] :]
1240
+
1241
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
1242
+ if inputs_embeds is not None and past_key_values is None:
1243
+ model_inputs = {"inputs_embeds": inputs_embeds}
1244
+ else:
1245
+ model_inputs = {"input_ids": input_ids}
1246
+
1247
+ model_inputs.update(
1248
+ {
1249
+ "position_ids": position_ids,
1250
+ "past_key_values": past_key_values,
1251
+ "use_cache": kwargs.get("use_cache"),
1252
+ "attention_mask": attention_mask,
1253
+ }
1254
+ )
1255
+ return model_inputs
1256
+
1257
+ @staticmethod
1258
+ def _reorder_cache(past_key_values, beam_idx):
1259
+ reordered_past = ()
1260
+ for layer_past in past_key_values:
1261
+ reordered_past += (
1262
+ tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
1263
+ )
1264
+ return reordered_past
1265
+
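
The position-id handling in `prepare_inputs_for_generation` above is worth seeing on a concrete mask: a cumulative sum gives each real token its position regardless of left padding, and padded slots get a harmless dummy value.

# Sketch: position ids derived from a left-padded attention mask, as done above.
import torch

attention_mask = torch.tensor([[0, 0, 1, 1, 1]])      # two pad tokens on the left
position_ids = attention_mask.long().cumsum(-1) - 1   # [[-1, -1, 0, 1, 2]]
position_ids.masked_fill_(attention_mask == 0, 1)     # [[ 1,  1, 0, 1, 2]]
# Padded slots get a dummy position (1); real tokens get 0, 1, 2 regardless of padding,
# and during cached decoding only the trailing slice is kept.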
1266
+
1267
+ @add_start_docstrings(
1268
+ """
1269
+ The Mistral Model transformer with a sequence classification head on top (linear layer).
1270
+
1271
+ [`MistralForSequenceClassification`] uses the last token in order to do the classification, as other causal models
1272
+ (e.g. GPT-2) do.
1273
+
1274
+ Since it does classification on the last token, it requires to know the position of the last token. If a
1275
+ `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
1276
+ no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
1277
+ padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
1278
+ each row of the batch).
1279
+ """,
1280
+ MISTRAL_START_DOCSTRING,
1281
+ )
1282
+ # Copied from transformers.models.llama.modeling_llama.LlamaForSequenceClassification with Llama->Mistral, LLAMA->MISTRAL
1283
+ class MistralForSequenceClassification(MistralPreTrainedModel):
1284
+ def __init__(self, config):
1285
+ super().__init__(config)
1286
+ self.num_labels = config.num_labels
1287
+ self.model = MistralModel(config)
1288
+ self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
1289
+
1290
+ # Initialize weights and apply final processing
1291
+ self.post_init()
1292
+
1293
+ def get_input_embeddings(self):
1294
+ return self.model.embed_tokens
1295
+
1296
+ def set_input_embeddings(self, value):
1297
+ self.model.embed_tokens = value
1298
+
1299
+ @add_start_docstrings_to_model_forward(MISTRAL_INPUTS_DOCSTRING)
1300
+ def forward(
1301
+ self,
1302
+ input_ids: torch.LongTensor = None,
1303
+ attention_mask: Optional[torch.Tensor] = None,
1304
+ position_ids: Optional[torch.LongTensor] = None,
1305
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
1306
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1307
+ labels: Optional[torch.LongTensor] = None,
1308
+ use_cache: Optional[bool] = None,
1309
+ output_attentions: Optional[bool] = None,
1310
+ output_hidden_states: Optional[bool] = None,
1311
+ return_dict: Optional[bool] = None,
1312
+ ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
1313
+ r"""
1314
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1315
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
1316
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1317
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1318
+ """
1319
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1320
+
1321
+ transformer_outputs = self.model(
1322
+ input_ids,
1323
+ attention_mask=attention_mask,
1324
+ position_ids=position_ids,
1325
+ past_key_values=past_key_values,
1326
+ inputs_embeds=inputs_embeds,
1327
+ use_cache=use_cache,
1328
+ output_attentions=output_attentions,
1329
+ output_hidden_states=output_hidden_states,
1330
+ return_dict=return_dict,
1331
+ )
1332
+ hidden_states = transformer_outputs[0]
1333
+ logits = self.score(hidden_states)
1334
+
1335
+ if input_ids is not None:
1336
+ batch_size = input_ids.shape[0]
1337
+ else:
1338
+ batch_size = inputs_embeds.shape[0]
1339
+
1340
+ if self.config.pad_token_id is None and batch_size != 1:
1341
+ raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
1342
+ if self.config.pad_token_id is None:
1343
+ sequence_lengths = -1
1344
+ else:
1345
+ if input_ids is not None:
1346
+ # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
1347
+ sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
1348
+ sequence_lengths = sequence_lengths % input_ids.shape[-1]
1349
+ sequence_lengths = sequence_lengths.to(logits.device)
1350
+ else:
1351
+ sequence_lengths = -1
1352
+
1353
+ pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
1354
+
1355
+ loss = None
1356
+ if labels is not None:
1357
+ labels = labels.to(logits.device)
1358
+ if self.config.problem_type is None:
1359
+ if self.num_labels == 1:
1360
+ self.config.problem_type = "regression"
1361
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
1362
+ self.config.problem_type = "single_label_classification"
1363
+ else:
1364
+ self.config.problem_type = "multi_label_classification"
1365
+
1366
+ if self.config.problem_type == "regression":
1367
+ loss_fct = MSELoss()
1368
+ if self.num_labels == 1:
1369
+ loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
1370
+ else:
1371
+ loss = loss_fct(pooled_logits, labels)
1372
+ elif self.config.problem_type == "single_label_classification":
1373
+ loss_fct = CrossEntropyLoss()
1374
+ loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
1375
+ elif self.config.problem_type == "multi_label_classification":
1376
+ loss_fct = BCEWithLogitsLoss()
1377
+ loss = loss_fct(pooled_logits, labels)
1378
+ if not return_dict:
1379
+ output = (pooled_logits,) + transformer_outputs[1:]
1380
+ return ((loss,) + output) if loss is not None else output
1381
+
1382
+ return SequenceClassifierOutputWithPast(
1383
+ loss=loss,
1384
+ logits=pooled_logits,
1385
+ past_key_values=transformer_outputs.past_key_values,
1386
+ hidden_states=transformer_outputs.hidden_states,
1387
+ attentions=transformer_outputs.attentions,
1388
+ )
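
A quick illustration (not part of the uploaded file) of the last-non-pad-token pooling that `MistralForSequenceClassification.forward` relies on above. The `pad_token_id` and tensors below are invented for the example; only the indexing trick mirrors the code.

```python
# Minimal sketch of the pooled-logit selection used by the classification head above.
import torch

pad_token_id = 0
input_ids = torch.tensor([[5, 7, 9, 0, 0],    # 3 real tokens, 2 pads
                          [4, 4, 4, 4, 2]])   # no padding
logits = torch.randn(2, 5, 3)                 # (batch, seq_len, num_labels)

# Index of the first pad token minus one == last real token. The modulo keeps the
# index valid when a row has no padding (argmax returns 0, so -1 wraps to the last
# position), mirroring the ONNX-friendly trick in the forward pass.
sequence_lengths = torch.eq(input_ids, pad_token_id).int().argmax(-1) - 1
sequence_lengths = sequence_lengths % input_ids.shape[-1]
pooled_logits = logits[torch.arange(input_ids.shape[0]), sequence_lengths]
print(pooled_logits.shape)  # torch.Size([2, 3]) -> one logit vector per sequence
```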
optims.py ADDED
@@ -0,0 +1,119 @@
1
+ """
2
+ Copyright (c) 2022, salesforce.com, inc.
3
+ All rights reserved.
4
+ SPDX-License-Identifier: BSD-3-Clause
5
+ For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+
8
+ import math
9
+
10
+ from minigpt4_video.registry import registry
11
+
12
+
13
+ @registry.register_lr_scheduler("linear_warmup_step_lr")
14
+ class LinearWarmupStepLRScheduler:
15
+ def __init__(
16
+ self,
17
+ optimizer,
18
+ max_epoch,
19
+ min_lr,
20
+ init_lr,
21
+ decay_rate=1,
22
+ warmup_start_lr=-1,
23
+ warmup_steps=0,
24
+ **kwargs
25
+ ):
26
+ self.optimizer = optimizer
27
+
28
+ self.max_epoch = max_epoch
29
+ self.min_lr = min_lr
30
+
31
+ self.decay_rate = decay_rate
32
+
33
+ self.init_lr = init_lr
34
+ self.warmup_steps = warmup_steps
35
+ self.warmup_start_lr = warmup_start_lr if warmup_start_lr >= 0 else init_lr
36
+
37
+ def step(self, cur_epoch, cur_step):
38
+ if cur_epoch == 0:
39
+ warmup_lr_schedule(
40
+ step=cur_step,
41
+ optimizer=self.optimizer,
42
+ max_step=self.warmup_steps,
43
+ init_lr=self.warmup_start_lr,
44
+ max_lr=self.init_lr,
45
+ )
46
+ else:
47
+ step_lr_schedule(
48
+ epoch=cur_epoch,
49
+ optimizer=self.optimizer,
50
+ init_lr=self.init_lr,
51
+ min_lr=self.min_lr,
52
+ decay_rate=self.decay_rate,
53
+ )
54
+
55
+
56
+ @registry.register_lr_scheduler("linear_warmup_cosine_lr")
57
+ class LinearWarmupCosineLRScheduler:
58
+ def __init__(
59
+ self,
60
+ optimizer,
61
+ max_epoch,
62
+ iters_per_epoch,
63
+ min_lr,
64
+ init_lr,
65
+ warmup_steps=0,
66
+ warmup_start_lr=-1,
67
+ **kwargs
68
+ ):
69
+ self.optimizer = optimizer
70
+
71
+ self.max_epoch = max_epoch
72
+ self.iters_per_epoch = iters_per_epoch
73
+ self.min_lr = min_lr
74
+
75
+ self.init_lr = init_lr
76
+ self.warmup_steps = warmup_steps
77
+ self.warmup_start_lr = warmup_start_lr if warmup_start_lr >= 0 else init_lr
78
+
79
+ def step(self, cur_epoch, cur_step):
80
+ total_cur_step = cur_epoch * self.iters_per_epoch + cur_step
81
+ if total_cur_step < self.warmup_steps:
82
+ warmup_lr_schedule(
83
+ step=total_cur_step,
84
+ optimizer=self.optimizer,
85
+ max_step=self.warmup_steps,
86
+ init_lr=self.warmup_start_lr,
87
+ max_lr=self.init_lr,
88
+ )
89
+ else:
90
+ cosine_lr_schedule(
91
+ epoch=total_cur_step,
92
+ optimizer=self.optimizer,
93
+ max_epoch=self.max_epoch * self.iters_per_epoch,
94
+ init_lr=self.init_lr,
95
+ min_lr=self.min_lr,
96
+ )
97
+
98
+
99
+ def cosine_lr_schedule(optimizer, epoch, max_epoch, init_lr, min_lr):
100
+ """Decay the learning rate"""
101
+ lr = (init_lr - min_lr) * 0.5 * (
102
+ 1.0 + math.cos(math.pi * epoch / max_epoch)
103
+ ) + min_lr
104
+ for param_group in optimizer.param_groups:
105
+ param_group["lr"] = lr
106
+
107
+
108
+ def warmup_lr_schedule(optimizer, step, max_step, init_lr, max_lr):
109
+ """Warmup the learning rate"""
110
+ lr = min(max_lr, init_lr + (max_lr - init_lr) * step / max(max_step, 1))
111
+ for param_group in optimizer.param_groups:
112
+ param_group["lr"] = lr
113
+
114
+
115
+ def step_lr_schedule(optimizer, epoch, init_lr, min_lr, decay_rate):
116
+ """Decay the learning rate"""
117
+ lr = max(min_lr, init_lr * (decay_rate**epoch))
118
+ for param_group in optimizer.param_groups:
119
+ param_group["lr"] = lr
registry.py ADDED
@@ -0,0 +1,330 @@
1
+ """
2
+ Copyright (c) 2022, salesforce.com, inc.
3
+ All rights reserved.
4
+ SPDX-License-Identifier: BSD-3-Clause
5
+ For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+
8
+
9
+ class Registry:
10
+ mapping = {
11
+ "builder_name_mapping": {},
12
+ "task_name_mapping": {},
13
+ "processor_name_mapping": {},
14
+ "model_name_mapping": {},
15
+ "lr_scheduler_name_mapping": {},
16
+ "runner_name_mapping": {},
17
+ "state": {},
18
+ "paths": {},
19
+ }
20
+
21
+ @classmethod
22
+ def register_builder(cls, name):
23
+ r"""Register a dataset builder to registry with key 'name'
24
+
25
+ Args:
26
+ name: Key with which the builder will be registered.
27
+
28
+ Usage:
29
+
30
+ from minigpt4.common.registry import registry
31
+ from minigpt4.datasets.base_dataset_builder import BaseDatasetBuilder
32
+ """
33
+
34
+ def wrap(builder_cls):
35
+ from minigpt4.datasets.builders.base_dataset_builder import BaseDatasetBuilder
36
+
37
+ assert issubclass(
38
+ builder_cls, BaseDatasetBuilder
39
+ ), "All builders must inherit BaseDatasetBuilder class, found {}".format(
40
+ builder_cls
41
+ )
42
+ if name in cls.mapping["builder_name_mapping"]:
43
+ raise KeyError(
44
+ "Name '{}' already registered for {}.".format(
45
+ name, cls.mapping["builder_name_mapping"][name]
46
+ )
47
+ )
48
+ cls.mapping["builder_name_mapping"][name] = builder_cls
49
+ return builder_cls
50
+
51
+ return wrap
52
+
53
+ @classmethod
54
+ def register_task(cls, name):
55
+ r"""Register a task to registry with key 'name'
56
+
57
+ Args:
58
+ name: Key with which the task will be registered.
59
+
60
+ Usage:
61
+
62
+ from minigpt4.common.registry import registry
63
+ """
64
+
65
+ def wrap(task_cls):
66
+ from minigpt4.tasks.base_task import BaseTask
67
+
68
+ assert issubclass(
69
+ task_cls, BaseTask
70
+ ), "All tasks must inherit BaseTask class"
71
+ if name in cls.mapping["task_name_mapping"]:
72
+ raise KeyError(
73
+ "Name '{}' already registered for {}.".format(
74
+ name, cls.mapping["task_name_mapping"][name]
75
+ )
76
+ )
77
+ cls.mapping["task_name_mapping"][name] = task_cls
78
+ return task_cls
79
+
80
+ return wrap
81
+
82
+ @classmethod
83
+ def register_model(cls, name):
84
+ r"""Register a task to registry with key 'name'
85
+
86
+ Args:
87
+ name: Key with which the model will be registered.
88
+
89
+ Usage:
90
+
91
+ from minigpt4.common.registry import registry
92
+ """
93
+
94
+ def wrap(model_cls):
95
+ # from minigpt4.models import BaseModel
96
+
97
+ # assert issubclass(
98
+ # model_cls, BaseModel
99
+ # ), "All models must inherit BaseModel class"
100
+
101
+ if name in cls.mapping["model_name_mapping"]:
102
+ raise KeyError(
103
+ "Name '{}' already registered for {}.".format(
104
+ name, cls.mapping["model_name_mapping"][name]
105
+ )
106
+ )
107
+ cls.mapping["model_name_mapping"][name] = model_cls
108
+ return model_cls
109
+
110
+ return wrap
111
+
112
+ @classmethod
113
+ def register_processor(cls, name):
114
+ r"""Register a processor to registry with key 'name'
115
+
116
+ Args:
117
+ name: Key with which the processor will be registered.
118
+
119
+ Usage:
120
+
121
+ from minigpt4.common.registry import registry
122
+ """
123
+
124
+ def wrap(processor_cls):
125
+ from minigpt4.processors import BaseProcessor
126
+
127
+ assert issubclass(
128
+ processor_cls, BaseProcessor
129
+ ), "All processors must inherit BaseProcessor class"
130
+ if name in cls.mapping["processor_name_mapping"]:
131
+ raise KeyError(
132
+ "Name '{}' already registered for {}.".format(
133
+ name, cls.mapping["processor_name_mapping"][name]
134
+ )
135
+ )
136
+ cls.mapping["processor_name_mapping"][name] = processor_cls
137
+ return processor_cls
138
+
139
+ return wrap
140
+
141
+ @classmethod
142
+ def register_lr_scheduler(cls, name):
143
+ r"""Register a model to registry with key 'name'
144
+
145
+ Args:
146
+ name: Key with which the lr scheduler will be registered.
147
+
148
+ Usage:
149
+
150
+ from minigpt4.common.registry import registry
151
+ """
152
+
153
+ def wrap(lr_sched_cls):
154
+ if name in cls.mapping["lr_scheduler_name_mapping"]:
155
+ raise KeyError(
156
+ "Name '{}' already registered for {}.".format(
157
+ name, cls.mapping["lr_scheduler_name_mapping"][name]
158
+ )
159
+ )
160
+ cls.mapping["lr_scheduler_name_mapping"][name] = lr_sched_cls
161
+ return lr_sched_cls
162
+
163
+ return wrap
164
+
165
+ @classmethod
166
+ def register_runner(cls, name):
167
+ r"""Register a model to registry with key 'name'
168
+
169
+ Args:
170
+ name: Key with which the runner will be registered.
171
+
172
+ Usage:
173
+
174
+ from minigpt4.common.registry import registry
175
+ """
176
+
177
+ def wrap(runner_cls):
178
+ if name in cls.mapping["runner_name_mapping"]:
179
+ raise KeyError(
180
+ "Name '{}' already registered for {}.".format(
181
+ name, cls.mapping["runner_name_mapping"][name]
182
+ )
183
+ )
184
+ cls.mapping["runner_name_mapping"][name] = runner_cls
185
+ return runner_cls
186
+
187
+ return wrap
188
+
189
+ @classmethod
190
+ def register_path(cls, name, path):
191
+ r"""Register a path to registry with key 'name'
192
+
193
+ Args:
194
+ name: Key with which the path will be registered.
195
+
196
+ Usage:
197
+
198
+ from minigpt4.common.registry import registry
199
+ """
200
+ assert isinstance(path, str), "All paths must be str."
201
+ if name in cls.mapping["paths"]:
202
+ raise KeyError("Name '{}' already registered.".format(name))
203
+ cls.mapping["paths"][name] = path
204
+
205
+ @classmethod
206
+ def register(cls, name, obj):
207
+ r"""Register an item to registry with key 'name'
208
+
209
+ Args:
210
+ name: Key with which the item will be registered.
211
+
212
+ Usage::
213
+
214
+ from minigpt4.common.registry import registry
215
+
216
+ registry.register("config", {})
217
+ """
218
+ path = name.split(".")
219
+ current = cls.mapping["state"]
220
+
221
+ for part in path[:-1]:
222
+ if part not in current:
223
+ current[part] = {}
224
+ current = current[part]
225
+
226
+ current[path[-1]] = obj
227
+
228
+ # @classmethod
229
+ # def get_trainer_class(cls, name):
230
+ # return cls.mapping["trainer_name_mapping"].get(name, None)
231
+
232
+ @classmethod
233
+ def get_builder_class(cls, name):
234
+ return cls.mapping["builder_name_mapping"].get(name, None)
235
+
236
+ @classmethod
237
+ def get_model_class(cls, name):
238
+ return cls.mapping["model_name_mapping"].get(name, None)
239
+
240
+ @classmethod
241
+ def get_task_class(cls, name):
242
+ return cls.mapping["task_name_mapping"].get(name, None)
243
+
244
+ @classmethod
245
+ def get_processor_class(cls, name):
246
+ return cls.mapping["processor_name_mapping"].get(name, None)
247
+
248
+ @classmethod
249
+ def get_lr_scheduler_class(cls, name):
250
+ return cls.mapping["lr_scheduler_name_mapping"].get(name, None)
251
+
252
+ @classmethod
253
+ def get_runner_class(cls, name):
254
+ return cls.mapping["runner_name_mapping"].get(name, None)
255
+
256
+ @classmethod
257
+ def list_runners(cls):
258
+ return sorted(cls.mapping["runner_name_mapping"].keys())
259
+
260
+ @classmethod
261
+ def list_models(cls):
262
+ return sorted(cls.mapping["model_name_mapping"].keys())
263
+
264
+ @classmethod
265
+ def list_tasks(cls):
266
+ return sorted(cls.mapping["task_name_mapping"].keys())
267
+
268
+ @classmethod
269
+ def list_processors(cls):
270
+ return sorted(cls.mapping["processor_name_mapping"].keys())
271
+
272
+ @classmethod
273
+ def list_lr_schedulers(cls):
274
+ return sorted(cls.mapping["lr_scheduler_name_mapping"].keys())
275
+
276
+ @classmethod
277
+ def list_datasets(cls):
278
+ return sorted(cls.mapping["builder_name_mapping"].keys())
279
+
280
+ @classmethod
281
+ def get_path(cls, name):
282
+ return cls.mapping["paths"].get(name, None)
283
+
284
+ @classmethod
285
+ def get(cls, name, default=None, no_warning=False):
286
+ r"""Get an item from registry with key 'name'
287
+
288
+ Args:
289
+ name (string): Key whose value needs to be retrieved.
290
+ default: If passed and key is not in registry, default value will
291
+ be returned with a warning. Default: None
292
+ no_warning (bool): If passed as True, warning when key doesn't exist
293
+ will not be generated. Useful for MMF's
294
+ internal operations. Default: False
295
+ """
296
+ original_name = name
297
+ name = name.split(".")
298
+ value = cls.mapping["state"]
299
+ for subname in name:
300
+ value = value.get(subname, default)
301
+ if value is default:
302
+ break
303
+
304
+ if (
305
+ "writer" in cls.mapping["state"]
306
+ and value == default
307
+ and no_warning is False
308
+ ):
309
+ cls.mapping["state"]["writer"].warning(
310
+ "Key {} is not present in registry, returning default value "
311
+ "of {}".format(original_name, default)
312
+ )
313
+ return value
314
+
315
+ @classmethod
316
+ def unregister(cls, name):
317
+ r"""Remove an item from registry with key 'name'
318
+
319
+ Args:
320
+ name: Key which needs to be removed.
321
+ Usage::
322
+
323
+ from mmf.common.registry import registry
324
+
325
+ config = registry.unregister("config")
326
+ """
327
+ return cls.mapping["state"].pop(name, None)
328
+
329
+
330
+ registry = Registry()
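
A hedged sketch of how this registry is typically used: store a value under a dotted key and look up a registered lr scheduler class by name. The `minigpt4_video.registry` import path matches the import used in optims.py above; note that name-based lookups return None unless the module that performs the registration has already been imported.

```python
# Sketch only: dotted-key state storage plus class lookup by registered name.
from minigpt4_video.registry import registry  # import path assumed from optims.py

registry.register("config.run.batch_size", 8)   # nested under mapping["state"]
print(registry.get("config.run.batch_size"))    # -> 8

# Returns the class decorated with @registry.register_lr_scheduler(...) in optims.py,
# or None if that module has not been imported yet.
sched_cls = registry.get_lr_scheduler_class("linear_warmup_cosine_lr")
print(sched_cls)
```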
utils.py ADDED
@@ -0,0 +1,424 @@
1
+ """
2
+ Copyright (c) 2022, salesforce.com, inc.
3
+ All rights reserved.
4
+ SPDX-License-Identifier: BSD-3-Clause
5
+ For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+
8
+ import io
9
+ import json
10
+ import logging
11
+ import os
12
+ import pickle
13
+ import re
14
+ import shutil
15
+ import urllib
16
+ import urllib.error
17
+ import urllib.request
18
+ from typing import Optional
19
+ from urllib.parse import urlparse
20
+
21
+ import numpy as np
22
+ import pandas as pd
23
+ import yaml
24
+ from iopath.common.download import download
25
+ from iopath.common.file_io import file_lock, g_pathmgr
26
+ from minigpt4_video.registry import registry
27
+ from torch.utils.model_zoo import tqdm
28
+ from torchvision.datasets.utils import (
29
+ check_integrity,
30
+ download_file_from_google_drive,
31
+ extract_archive,
32
+ )
33
+
34
+
35
+ def now():
36
+ from datetime import datetime
37
+
38
+ return datetime.now().strftime("%Y%m%d%H%M")
39
+
40
+
41
+ def is_url(url_or_filename):
42
+ parsed = urlparse(url_or_filename)
43
+ return parsed.scheme in ("http", "https")
44
+
45
+
46
+ def get_cache_path(rel_path):
47
+ return os.path.expanduser(os.path.join(registry.get_path("cache_root"), rel_path))
48
+
49
+
50
+ def get_abs_path(rel_path):
51
+ return os.path.join(registry.get_path("library_root"), rel_path)
52
+
53
+
54
+ def load_json(filename):
55
+ with open(filename, "r") as f:
56
+ return json.load(f)
57
+
58
+
59
+ # The following are adapted from torchvision and vissl
60
+ # torchvision: https://github.com/pytorch/vision
61
+ # vissl: https://github.com/facebookresearch/vissl/blob/main/vissl/utils/download.py
62
+
63
+
64
+ def makedir(dir_path):
65
+ """
66
+ Create the directory if it does not exist.
67
+ """
68
+ is_success = False
69
+ try:
70
+ if not g_pathmgr.exists(dir_path):
71
+ g_pathmgr.mkdirs(dir_path)
72
+ is_success = True
73
+ except BaseException:
74
+ print(f"Error creating directory: {dir_path}")
75
+ return is_success
76
+
77
+
78
+ def get_redirected_url(url: str):
79
+ """
80
+ Given a URL, returns the URL it redirects to or the
81
+ original URL in case of no indirection
82
+ """
83
+ import requests
84
+
85
+ with requests.Session() as session:
86
+ with session.get(url, stream=True, allow_redirects=True) as response:
87
+ if response.history:
88
+ return response.url
89
+ else:
90
+ return url
91
+
92
+
93
+ def to_google_drive_download_url(view_url: str) -> str:
94
+ """
95
+ Utility function to transform a view URL of google drive
96
+ to a download URL for google drive
97
+ Example input:
98
+ https://drive.google.com/file/d/137RyRjvTBkBiIfeYBNZBtViDHQ6_Ewsp/view
99
+ Example output:
100
+ https://drive.google.com/uc?export=download&id=137RyRjvTBkBiIfeYBNZBtViDHQ6_Ewsp
101
+ """
102
+ splits = view_url.split("/")
103
+ assert splits[-1] == "view"
104
+ file_id = splits[-2]
105
+ return f"https://drive.google.com/uc?export=download&id={file_id}"
106
+
107
+
108
+ def download_google_drive_url(url: str, output_path: str, output_file_name: str):
109
+ """
110
+ Download a file from google drive
111
+ Downloading an URL from google drive requires confirmation when
112
+ the file of the size is too big (google drive notifies that
113
+ anti-viral checks cannot be performed on such files)
114
+ """
115
+ import requests
116
+
117
+ with requests.Session() as session:
118
+
119
+ # First get the confirmation token and append it to the URL
120
+ with session.get(url, stream=True, allow_redirects=True) as response:
121
+ for k, v in response.cookies.items():
122
+ if k.startswith("download_warning"):
123
+ url = url + "&confirm=" + v
124
+
125
+ # Then download the content of the file
126
+ with session.get(url, stream=True, verify=True) as response:
127
+ makedir(output_path)
128
+ path = os.path.join(output_path, output_file_name)
129
+ total_size = int(response.headers.get("Content-length", 0))
130
+ with open(path, "wb") as file:
131
+ from tqdm import tqdm
132
+
133
+ with tqdm(total=total_size) as progress_bar:
134
+ for block in response.iter_content(
135
+ chunk_size=io.DEFAULT_BUFFER_SIZE
136
+ ):
137
+ file.write(block)
138
+ progress_bar.update(len(block))
139
+
140
+
141
+ def _get_google_drive_file_id(url: str) -> Optional[str]:
142
+ parts = urlparse(url)
143
+
144
+ if re.match(r"(drive|docs)[.]google[.]com", parts.netloc) is None:
145
+ return None
146
+
147
+ match = re.match(r"/file/d/(?P<id>[^/]*)", parts.path)
148
+ if match is None:
149
+ return None
150
+
151
+ return match.group("id")
152
+
153
+
154
+ def _urlretrieve(url: str, filename: str, chunk_size: int = 1024) -> None:
155
+ with open(filename, "wb") as fh:
156
+ with urllib.request.urlopen(
157
+ urllib.request.Request(url, headers={"User-Agent": "vissl"})
158
+ ) as response:
159
+ with tqdm(total=response.length) as pbar:
160
+ for chunk in iter(lambda: response.read(chunk_size), b""):
161
+ if not chunk:
162
+ break
163
+ pbar.update(chunk_size)
164
+ fh.write(chunk)
165
+
166
+
167
+ def download_url(
168
+ url: str,
169
+ root: str,
170
+ filename: Optional[str] = None,
171
+ md5: Optional[str] = None,
172
+ ) -> None:
173
+ """Download a file from a url and place it in root.
174
+ Args:
175
+ url (str): URL to download file from
176
+ root (str): Directory to place downloaded file in
177
+ filename (str, optional): Name to save the file under.
178
+ If None, use the basename of the URL.
179
+ md5 (str, optional): MD5 checksum of the download. If None, do not check
180
+ """
181
+ root = os.path.expanduser(root)
182
+ if not filename:
183
+ filename = os.path.basename(url)
184
+ fpath = os.path.join(root, filename)
185
+
186
+ makedir(root)
187
+
188
+ # check if file is already present locally
189
+ if check_integrity(fpath, md5):
190
+ print("Using downloaded and verified file: " + fpath)
191
+ return
192
+
193
+ # expand redirect chain if needed
194
+ url = get_redirected_url(url)
195
+
196
+ # check if file is located on Google Drive
197
+ file_id = _get_google_drive_file_id(url)
198
+ if file_id is not None:
199
+ return download_file_from_google_drive(file_id, root, filename, md5)
200
+
201
+ # download the file
202
+ try:
203
+ print("Downloading " + url + " to " + fpath)
204
+ _urlretrieve(url, fpath)
205
+ except (urllib.error.URLError, IOError) as e: # type: ignore[attr-defined]
206
+ if url[:5] == "https":
207
+ url = url.replace("https:", "http:")
208
+ print(
209
+ "Failed download. Trying https -> http instead."
210
+ " Downloading " + url + " to " + fpath
211
+ )
212
+ _urlretrieve(url, fpath)
213
+ else:
214
+ raise e
215
+
216
+ # check integrity of downloaded file
217
+ if not check_integrity(fpath, md5):
218
+ raise RuntimeError("File not found or corrupted.")
219
+
220
+
221
+ def download_and_extract_archive(
222
+ url: str,
223
+ download_root: str,
224
+ extract_root: Optional[str] = None,
225
+ filename: Optional[str] = None,
226
+ md5: Optional[str] = None,
227
+ remove_finished: bool = False,
228
+ ) -> None:
229
+ download_root = os.path.expanduser(download_root)
230
+ if extract_root is None:
231
+ extract_root = download_root
232
+ if not filename:
233
+ filename = os.path.basename(url)
234
+
235
+ download_url(url, download_root, filename, md5)
236
+
237
+ archive = os.path.join(download_root, filename)
238
+ print("Extracting {} to {}".format(archive, extract_root))
239
+ extract_archive(archive, extract_root, remove_finished)
240
+
241
+
242
+ def cache_url(url: str, cache_dir: str) -> str:
243
+ """
244
+ This implementation downloads the remote resource and caches it locally.
245
+ The resource will only be downloaded if not previously requested.
246
+ """
247
+ parsed_url = urlparse(url)
248
+ dirname = os.path.join(cache_dir, os.path.dirname(parsed_url.path.lstrip("/")))
249
+ makedir(dirname)
250
+ filename = url.split("/")[-1]
251
+ cached = os.path.join(dirname, filename)
252
+ with file_lock(cached):
253
+ if not os.path.isfile(cached):
254
+ logging.info(f"Downloading {url} to {cached} ...")
255
+ cached = download(url, dirname, filename=filename)
256
+ logging.info(f"URL {url} cached in {cached}")
257
+ return cached
258
+
259
+
260
+ # TODO (prigoyal): convert this into RAII-style API
261
+ def create_file_symlink(file1, file2):
262
+ """
263
+ Simply create the symlinks for a given file1 to file2.
264
+ Useful during model checkpointing to symlinks to the
265
+ latest successful checkpoint.
266
+ """
267
+ try:
268
+ if g_pathmgr.exists(file2):
269
+ g_pathmgr.rm(file2)
270
+ g_pathmgr.symlink(file1, file2)
271
+ except Exception as e:
272
+ logging.info(f"Could NOT create symlink. Error: {e}")
273
+
274
+
275
+ def save_file(data, filename, append_to_json=True, verbose=True):
276
+ """
277
+ Common i/o utility to handle saving data to various file formats.
278
+ Supported:
279
+ .pkl, .pickle, .npy, .json
280
+ Specifically for .json, users have the option to either append (default)
281
+ or overwrite by passing a Boolean value to append_to_json.
282
+ """
283
+ if verbose:
284
+ logging.info(f"Saving data to file: {filename}")
285
+ file_ext = os.path.splitext(filename)[1]
286
+ if file_ext in [".pkl", ".pickle"]:
287
+ with g_pathmgr.open(filename, "wb") as fopen:
288
+ pickle.dump(data, fopen, pickle.HIGHEST_PROTOCOL)
289
+ elif file_ext == ".npy":
290
+ with g_pathmgr.open(filename, "wb") as fopen:
291
+ np.save(fopen, data)
292
+ elif file_ext == ".json":
293
+ if append_to_json:
294
+ with g_pathmgr.open(filename, "a") as fopen:
295
+ fopen.write(json.dumps(data, sort_keys=True) + "\n")
296
+ fopen.flush()
297
+ else:
298
+ with g_pathmgr.open(filename, "w") as fopen:
299
+ fopen.write(json.dumps(data, sort_keys=True) + "\n")
300
+ fopen.flush()
301
+ elif file_ext == ".yaml":
302
+ with g_pathmgr.open(filename, "w") as fopen:
303
+ dump = yaml.dump(data)
304
+ fopen.write(dump)
305
+ fopen.flush()
306
+ else:
307
+ raise Exception(f"Saving {file_ext} is not supported yet")
308
+
309
+ if verbose:
310
+ logging.info(f"Saved data to file: {filename}")
311
+
312
+
313
+ def load_file(filename, mmap_mode=None, verbose=True, allow_pickle=False):
314
+ """
315
+ Common i/o utility to handle loading data from various file formats.
316
+ Supported:
317
+ .pkl, .pickle, .npy, .json
318
+ For the npy files, we support reading the files in mmap_mode.
319
+ If the mmap_mode of reading is not successful, we load data without the
320
+ mmap_mode.
321
+ """
322
+ if verbose:
323
+ logging.info(f"Loading data from file: {filename}")
324
+
325
+ file_ext = os.path.splitext(filename)[1]
326
+ if file_ext == ".txt":
327
+ with g_pathmgr.open(filename, "r") as fopen:
328
+ data = fopen.readlines()
329
+ elif file_ext in [".pkl", ".pickle"]:
330
+ with g_pathmgr.open(filename, "rb") as fopen:
331
+ data = pickle.load(fopen, encoding="latin1")
332
+ elif file_ext == ".npy":
333
+ if mmap_mode:
334
+ try:
335
+ with g_pathmgr.open(filename, "rb") as fopen:
336
+ data = np.load(
337
+ fopen,
338
+ allow_pickle=allow_pickle,
339
+ encoding="latin1",
340
+ mmap_mode=mmap_mode,
341
+ )
342
+ except ValueError as e:
343
+ logging.info(
344
+ f"Could not mmap {filename}: {e}. Trying without g_pathmgr"
345
+ )
346
+ data = np.load(
347
+ filename,
348
+ allow_pickle=allow_pickle,
349
+ encoding="latin1",
350
+ mmap_mode=mmap_mode,
351
+ )
352
+ logging.info("Successfully loaded without g_pathmgr")
353
+ except Exception:
354
+ logging.info("Could not mmap without g_pathmgr. Trying without mmap")
355
+ with g_pathmgr.open(filename, "rb") as fopen:
356
+ data = np.load(fopen, allow_pickle=allow_pickle, encoding="latin1")
357
+ else:
358
+ with g_pathmgr.open(filename, "rb") as fopen:
359
+ data = np.load(fopen, allow_pickle=allow_pickle, encoding="latin1")
360
+ elif file_ext == ".json":
361
+ with g_pathmgr.open(filename, "r") as fopen:
362
+ data = json.load(fopen)
363
+ elif file_ext == ".yaml":
364
+ with g_pathmgr.open(filename, "r") as fopen:
365
+ data = yaml.load(fopen, Loader=yaml.FullLoader)
366
+ elif file_ext == ".csv":
367
+ with g_pathmgr.open(filename, "r") as fopen:
368
+ data = pd.read_csv(fopen)
369
+ else:
370
+ raise Exception(f"Reading from {file_ext} is not supported yet")
371
+ return data
372
+
373
+
374
+ def abspath(resource_path: str):
375
+ """
376
+ Make a path absolute, but take into account prefixes like
377
+ "http://" or "manifold://"
378
+ """
379
+ regex = re.compile(r"^\w+://")
380
+ if regex.match(resource_path) is None:
381
+ return os.path.abspath(resource_path)
382
+ else:
383
+ return resource_path
384
+
385
+
386
+ def makedir(dir_path):
387
+ """
388
+ Create the directory if it does not exist.
389
+ """
390
+ is_success = False
391
+ try:
392
+ if not g_pathmgr.exists(dir_path):
393
+ g_pathmgr.mkdirs(dir_path)
394
+ is_success = True
395
+ except BaseException:
396
+ logging.info(f"Error creating directory: {dir_path}")
397
+ return is_success
398
+
399
+
400
+ def is_url(input_url):
401
+ """
402
+ Check if an input string is a url: look for http(s):// and ignore the case
403
+ """
404
+ is_url = re.match(r"^(?:http)s?://", input_url, re.IGNORECASE) is not None
405
+ return is_url
406
+
407
+
408
+ def cleanup_dir(dir):
409
+ """
410
+ Utility for deleting a directory. Useful for cleaning the storage space
411
+ that contains various training artifacts like checkpoints, data etc.
412
+ """
413
+ if os.path.exists(dir):
414
+ logging.info(f"Deleting directory: {dir}")
415
+ shutil.rmtree(dir)
416
+ logging.info(f"Deleted contents of directory: {dir}")
417
+
418
+
419
+ def get_file_size(filename):
420
+ """
421
+ Given a file, get the size of file in MB
422
+ """
423
+ size_in_mb = os.path.getsize(filename) / float(1024**2)
424
+ return size_in_mb
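
A small hedged example exercising the two Google Drive helpers above, using the sample URL from the `to_google_drive_download_url` docstring. The `minigpt4_video.utils` module path is an assumption based on the package layout of this upload, and `_get_google_drive_file_id` is an internal helper used here only for illustration.

```python
# Sketch only: convert a Drive "view" URL into its download form and extract the id.
from minigpt4_video.utils import to_google_drive_download_url, _get_google_drive_file_id

view_url = "https://drive.google.com/file/d/137RyRjvTBkBiIfeYBNZBtViDHQ6_Ewsp/view"
print(_get_google_drive_file_id(view_url))
# 137RyRjvTBkBiIfeYBNZBtViDHQ6_Ewsp
print(to_google_drive_download_url(view_url))
# https://drive.google.com/uc?export=download&id=137RyRjvTBkBiIfeYBNZBtViDHQ6_Ewsp
```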