jiajunlong committed on
Commit
f90ff9d
1 Parent(s): 1c7edc5

Update modeling_tinyllava_elm.py

Files changed (1)
  1. modeling_tinyllava_elm.py +1287 -6
modeling_tinyllava_elm.py CHANGED
@@ -5,16 +5,16 @@ import re
5
 
6
  import torch
7
  import torch.utils.checkpoint
8
- from torch import nn
9
 
10
  from transformers import PreTrainedModel
11
  from transformers.modeling_outputs import CausalLMOutputWithPast
12
  from transformers.generation.utils import GenerateOutput
13
  from transformers import CLIPVisionModel, CLIPImageProcessor,SiglipVisionModel, SiglipImageProcessor
14
 
15
- from configuration import TinyLlavaConfig
16
- from utils import *
17
- from modeling_elm import OpenELMForCausalLM
18
 
19
  # from tinyllava.utils.data_utils import get_value_from_kwargs
20
  CONTROLLER_HEART_BEAT_EXPIRATION = 30
@@ -22,6 +22,1287 @@ WORKER_HEART_BEAT_INTERVAL = 15
22
 
23
  LOGDIR = "."
24
  import os
25
 
26
  ACT_TYPE = {
27
  'relu': nn.ReLU,
@@ -409,5 +1690,5 @@ class TinyLlavaForConditionalGeneration(TinyLlavaPreTrainedModel):
409
 
410
 
411
 
412
-
413
-
 
5
 
6
  import torch
7
  import torch.utils.checkpoint
8
+ from torch import nn, Tensor
9
 
10
  from transformers import PreTrainedModel
11
  from transformers.modeling_outputs import CausalLMOutputWithPast
12
  from transformers.generation.utils import GenerateOutput
13
  from transformers import CLIPVisionModel, CLIPImageProcessor,SiglipVisionModel, SiglipImageProcessor
14
 
15
+ from .configuration import TinyLlavaConfig, IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
16
+
17
+ from transformers import AutoConfig, AutoModelForCausalLM
18
 
19
  # from tinyllava.utils.data_utils import get_value_from_kwargs
20
  CONTROLLER_HEART_BEAT_EXPIRATION = 30
 
22
 
23
  LOGDIR = "."
24
  import os
25
+ #
26
+ # For licensing see accompanying LICENSE file.
27
+ # Copyright (C) 2024 Apple Inc. All Rights Reserved.
28
+ #
29
+
30
+ from torch.nn import CrossEntropyLoss
31
+ from transformers.activations import ACT2FN
32
+ from transformers.cache_utils import Cache, DynamicCache, StaticCache
33
+ from transformers.modeling_outputs import (
34
+ BaseModelOutputWithPast,
35
+ )
36
+ from transformers.utils import logging
37
+
38
+ logger = logging.get_logger(__name__)
39
+
40
+ # this import has to be relative, otherwise, when setting trust_remote_code=True
41
+ # huggingface transformers won't be able to load the module correctly
42
+ from numbers import Number
43
+ from typing import List, Optional, Union
44
+
45
+ import numpy as np
46
+ from transformers import PretrainedConfig, AutoTokenizer
47
+
48
+
49
+ def make_divisible(
50
+ v: Union[float, int],
51
+ divisor: Optional[int] = 8,
52
+ min_value: Optional[Union[float, int]] = None,
53
+ ) -> Union[float, int]:
54
+ """
55
+ This function is taken from the original tf repo.
56
+ It ensures that all layers have a channel number that is divisible by the divisor
57
+ It can be seen at:
58
+ https://github.com/tensorflow/models/blob/2cfc99eff5e5eb729c6793d2f3d03aa1c9be2b15/research/slim/nets/mobilenet/mobilenet.py#L62
59
+ Args:
60
+ v: input value
61
+ divisor: default to 8
62
+ min_value: minimum allowed value; defaults to the divisor
63
+ Returns:
64
+ new_v: new divisible value
65
+ """
66
+ if min_value is None:
67
+ min_value = divisor
68
+ new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
69
+ # Make sure that round down does not go down by more than 10%.
70
+ if new_v < 0.9 * v:
71
+ new_v += divisor
72
+ return new_v
73
+
74
+
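# Editorial sketch, not part of the commit: expected behaviour of make_divisible
# under its defaults, following the rounding logic above.
#   make_divisible(100, divisor=8) -> 104   # int(100 + 4) // 8 * 8 == 104, and 104 >= 0.9 * 100
#   make_divisible(67, divisor=8)  -> 64    # int(67 + 4) // 8 * 8 == 64, and 64 >= 0.9 * 67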
75
+ def compute_heads(model_dim: int, head_dim: int) -> int:
76
+ """Compute the number of heads.
77
+ Args:
78
+ model_dim: Model dimension.
79
+ head_dim: Head dimension.
80
+ Returns:
81
+ An integer denoting the number of heads in multi-head attention.
82
+ Raises:
83
+ ValueError: if model dimension is not divisible by head dimension.
84
+ """
85
+ if model_dim % head_dim == 0:
86
+ return model_dim // head_dim
87
+ else:
88
+ raise ValueError(
89
+ f"Model dimension should be divisible by head dimension. Got: {model_dim} and {head_dim}."
90
+ )
91
+
92
+
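# Editorial sketch, not part of the commit: compute_heads simply divides the two
# dimensions and guards against non-divisible configurations.
#   compute_heads(model_dim=1280, head_dim=64) -> 20
#   compute_heads(model_dim=1280, head_dim=48) -> raises ValueError (1280 % 48 != 0)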
93
+ OpenELM_CONFIGS = {
94
+ "OpenELM-270M": dict(
95
+ num_transformer_layers=16,
96
+ model_dim=1280,
97
+ head_dim=64,
98
+ num_gqa_groups=4,
99
+ normalize_qk_projections=True,
100
+ share_input_output_layers=True,
101
+ # Vary the FFN and QKV multipliers to create variable FFN and attention layers respectively.
102
+ ffn_multipliers=(0.5, 4.0),
103
+ qkv_multipliers=(0.5, 1.0),
104
+ ),
105
+ "OpenELM-450M": dict(
106
+ num_transformer_layers=20,
107
+ model_dim=1536,
108
+ head_dim=64,
109
+ num_gqa_groups=4,
110
+ normalize_qk_projections=True,
111
+ share_input_output_layers=True,
112
+ # Vary the FFN and QKV multipliers to create variable FFN and attention layers respectively.
113
+ ffn_multipliers=(0.5, 4.0),
114
+ qkv_multipliers=(0.5, 1.0),
115
+ ),
116
+ "OpenELM-1_1B": dict(
117
+ num_transformer_layers=28,
118
+ model_dim=2048,
119
+ head_dim=64,
120
+ num_gqa_groups=4,
121
+ normalize_qk_projections=True,
122
+ share_input_output_layers=True,
123
+ # Vary the FFN and QKV multipliers to create variable FFN and attention layers respectively.
124
+ ffn_multipliers=(0.5, 4.0),
125
+ qkv_multipliers=(0.5, 1.0),
126
+ ),
127
+ "OpenELM-3B": dict(
128
+ num_transformer_layers=36,
129
+ model_dim=3072,
130
+ head_dim=128,
131
+ num_gqa_groups=4,
132
+ normalize_qk_projections=True,
133
+ share_input_output_layers=True,
134
+ # Vary the FFN and QKV multipliers to create variable FFN and attention layers respectively.
135
+ ffn_multipliers=(0.5, 4.0),
136
+ qkv_multipliers=(0.5, 1.0),
137
+ ),
138
+ }
139
+
140
+
141
+ class OpenELMConfig(PretrainedConfig):
142
+ r"""
143
+ This is the configuration class to store the configuration of a [`OpenELMModel`]. It is used to instantiate an OpenELM model according to the specified arguments, defining the model architecture.
144
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
145
+ documentation from [`PretrainedConfig`] for more information.
146
+ Args:
147
+ vocab_size (`int`, *optional*, defaults to 32000):
148
+ Vocabulary size of the OpenELM model.
149
+ max_context_length (`int`, *optional*, defaults to 2048):
150
+ Maximum number of input tokens.
151
+ num_transformer_layers (`int`, *optional*, defaults to 12):
152
+ Number of hidden layers in the Transformer decoder.
153
+ model_dim (`int`, *optional*, defaults to 2048):
154
+ Dimension of the hidden representations.
155
+ head_dim (`int`, *optional*, defaults to 128):
156
+ The attention head dimension.
157
+ qkv_multipliers (`Union[Number, List[Number]]`, *optional*, defaults to 1.0):
158
+ If the qkv_multipliers is a Number, then all attention layers have the same latent dimensions,
159
+ resulting in uniform allocation of parameters.
160
+ If the qkv_multipliers is a List of Number, then each attention layer has different latent dimensions,
161
+ assuming qkv_multipliers[0] != qkv_multipliers[1]. This results in variable allocation of parameters in the attention layers.
162
+ This scaling is known as layer-wise or block-wise scaling: https://arxiv.org/abs/2008.00623
163
+ num_query_heads (`Union[int, None]`, *optional*, defaults to None):
164
+ The number of query heads, computed from `compute_heads(model_dim=model_dim, head_dim=head_dim)`.
165
+ num_gqa_groups (`int`, *optional*, defaults to 1):
166
+ This variable allows switching between multi-head attention, group query attention, and multi-query attention.
167
+ When num_gqa_groups == 1, then it is multi-head attention.
168
+ When 1 < num_gqa_groups < num_heads and num_heads is divisible by num_gqa_groups, then it is group query attention.
169
+ When num_gqa_groups == num_heads, then it is multi-query attention.
170
+ ffn_multipliers (`Union[Number, List[Number]]`, *optional*, defaults to 4.0):
171
+ Feed-forward network (FFN) multipliers.
172
+ If the ffn_multipliers is a Number, then all FFN layers have the same latent dimensions,
173
+ resulting in uniform allocation of parameters.
174
+ If the ffn_multipliers is a List of Number, then each FFN layer has different latent dimensions,
175
+ assuming ffn_multipliers[0] != ffn_multipliers[1]. This results in variable allocation of parameters in the FFN layers.
176
+ This scaling is known as layer-wise or block-wise scaling: https://arxiv.org/abs/2008.00623
177
+ ffn_with_glu (`bool`, *optional*, defaults to True):
178
+ Whether to use FFN with Gated Linear Unit (GLU)
179
+ ffn_dim_divisor (`int`, *optional*, defaults to 256):
180
+ The ffn layer dimension divisor.
181
+ activation_fn_name (`str` or `function`, *optional*, defaults to `"swish"`):
182
+ The non-linear activation function (function or string) in the decoder.
183
+ normalization_layer_name (`str` or `function`, *optional*, defaults to `"rms_norm"`):
184
+ Type of normalization layer.
185
+ normalize_qk_projections (`bool`, *optional*, defaults to False):
186
+ Whether to normalize queries and keys after projections
187
+ share_input_output_layers (`bool`, *optional*, defaults to False):
188
+ Whether to share the embedding between input and output linear layer
189
+ rope_freq_constant (`int`, *optional*, defaults to 10000):
190
+ The base period of the RoPE embeddings.
191
+ rope_max_length (`int`, *optional*, defaults to 4096):
192
+ Note that rope_max_length is set to twice the max_context_length.
193
+ This allows flexibility in token lengths during training or fine-tuning.
194
+ initializer_range (`float`, *optional*, defaults to 0.02):
195
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
196
+ use_cache (`bool`, *optional*, defaults to `True`):
197
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
198
+ relevant if `config.is_decoder=True`.
199
+ bos_token_id (`int`, *optional*, defaults to 2):
200
+ Beginning of stream token id.
201
+ eos_token_id (`int`, *optional*, defaults to 1):
202
+ End of stream token id.
203
+ """
204
+
205
+ model_type = "openelm"
206
+
207
+ def __init__(
208
+ self,
209
+ vocab_size: int = 32000,
210
+ max_context_length: int = 2048,
211
+ num_transformer_layers: int = 12,
212
+ model_dim: int = 2048,
213
+ head_dim: int = 128,
214
+ qkv_multipliers: Union[Number, List[Number]] = 1.0,
215
+ num_query_heads: Union[int, None] = None,
216
+ num_gqa_groups: int = 1,
217
+ ffn_multipliers: Union[Number, List[Number]] = 4.0,
218
+ ffn_with_glu: bool = True,
219
+ ffn_dim_divisor: int = 256,
220
+ activation_fn_name: str = "swish",
221
+ normalization_layer_name: str = "rms_norm",
222
+ normalize_qk_projections: bool = False,
223
+ share_input_output_layers: bool = False,
224
+ rope_freq_constant: int = 10000,
225
+ rope_max_length: int = 4096,
226
+ initializer_range: float = 0.02,
227
+ use_cache: bool = True,
228
+ bos_token_id: int = 1,
229
+ eos_token_id: int = 2,
230
+ **kwargs,
231
+ ) -> None:
232
+ self.vocab_size = vocab_size
233
+ self.max_context_length = max_context_length
234
+ self.num_transformer_layers = num_transformer_layers
235
+ self.model_dim = model_dim
236
+ self.head_dim = head_dim
237
+ self.qkv_multipliers = qkv_multipliers
238
+ self.num_query_heads = num_query_heads
239
+ self.num_gqa_groups = num_gqa_groups
240
+ self.ffn_multipliers = ffn_multipliers
241
+ self.ffn_with_glu = ffn_with_glu
242
+ self.ffn_dim_divisor = ffn_dim_divisor
243
+ self.activation_fn_name = activation_fn_name
244
+ self.normalization_layer_name = normalization_layer_name
245
+ self.normalize_qk_projections = normalize_qk_projections
246
+ self.share_input_output_layers = share_input_output_layers
247
+ self.rope_freq_constant = rope_freq_constant
248
+ self.rope_max_length = rope_max_length
249
+ self.num_query_heads = (
250
+ compute_heads(model_dim=model_dim, head_dim=head_dim)
251
+ if num_query_heads is None
252
+ else num_query_heads
253
+ )
254
+ self.initializer_range = initializer_range
255
+
256
+ self.__post_init__()
257
+ super().__init__(
258
+ use_cache=use_cache,
259
+ bos_token_id=bos_token_id,
260
+ eos_token_id=eos_token_id,
261
+ **kwargs,
262
+ )
263
+
264
+ def __post_init__(self) -> None:
265
+ if self.num_gqa_groups is not None:
266
+ head_multiple_of = self.num_gqa_groups
267
+ else:
268
+ head_multiple_of = 2
269
+
270
+ if isinstance(self.qkv_multipliers, Number):
271
+ # All attention layers have the same latent dimensions, resulting in uniform allocation of parameters.
272
+ qkv_dim = make_divisible(
273
+ self.model_dim * self.qkv_multipliers,
274
+ divisor=self.head_dim * head_multiple_of,
275
+ )
276
+ query_dims = [int(qkv_dim)] * self.num_transformer_layers
277
+
278
+ elif (
279
+ isinstance(self.qkv_multipliers, (tuple, list))
280
+ and len(self.qkv_multipliers) == 2
281
+ ):
282
+ # Each attention layer have different latent dimensions assuming qkv_multipliers[0] != qkv_multipliers[1].
283
+ # This results in variable allocation of parameters in attention layer.
284
+ # This scaling is known as layer-wise or block-wise scaling: https://arxiv.org/abs/2008.00623
285
+ qkv_multipliers = [
286
+ round(v, 2)
287
+ for v in np.linspace(
288
+ self.qkv_multipliers[0],
289
+ self.qkv_multipliers[1],
290
+ num=self.num_transformer_layers,
291
+ dtype=float,
292
+ )
293
+ ]
294
+ # Make sure that scaled model dimension is divisible by scaled head dimension.
295
+ query_dims = [
296
+ int(
297
+ make_divisible(
298
+ self.model_dim * m, divisor=self.head_dim * head_multiple_of
299
+ )
300
+ )
301
+ for m in qkv_multipliers
302
+ ]
303
+ else:
304
+ raise NotImplementedError(
305
+ f"QKV multipliers should be a single number or a list containing exactly two numbers. Got: {qkv_multipliers}."
306
+ )
307
+
308
+ # compute the number of query, key, and value heads
309
+ # For multi-head and multi-query attention, the number of heads for query, key, and value are the same.
310
+ # For group query attention, the number of key and value heads are the same.
311
+ self.num_query_heads = [
312
+ int(compute_heads(q_dim, self.head_dim)) for q_dim in query_dims
313
+ ]
314
+ self.num_kv_heads = [
315
+ q_heads // self.num_gqa_groups for q_heads in self.num_query_heads
316
+ ]
317
+
318
+ # Feed-forward network (FFN) multipliers
319
+ if isinstance(self.ffn_multipliers, Number):
320
+ # All FFN layers have the same latent dimensions, resulting in uniform allocation of parameters.
321
+ self.ffn_multipliers = [self.ffn_multipliers] * self.num_transformer_layers
322
+ elif isinstance(self.ffn_multipliers, (tuple, list)):
323
+ # Each FFN layer have different latent dimensions assuming ffn_multipliers[0] != ffn_multipliers[1].
324
+ # This results in variable allocation of parameters in FFN layer.
325
+ # This scaling is known as layer-wise or block-wise scaling: https://arxiv.org/abs/2008.00623
326
+ if len(self.ffn_multipliers) == 2:
327
+ self.ffn_multipliers = [
328
+ round(v, 2)
329
+ for v in np.linspace(
330
+ self.ffn_multipliers[0],
331
+ self.ffn_multipliers[1],
332
+ num=self.num_transformer_layers,
333
+ dtype=float,
334
+ )
335
+ ]
336
+ else:
337
+ assert (
338
+ len(self.ffn_multipliers) == self.num_transformer_layers
339
+ ), f"{len(self.ffn_multipliers)=}!={self.num_transformer_layers=}"
340
+ else:
341
+ raise NotImplementedError(
342
+ f"FFN multipliers should be a single number or a list containing exactly two numbers. Got: {qkv_multipliers}."
343
+ )
344
+
345
+ # check num_query_heads divisible by num_kv_heads for every layer
346
+ for layer_idx in range(len(query_dims)):
347
+ assert self.num_query_heads[layer_idx] % self.num_kv_heads[layer_idx] == 0
348
+
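# Editorial usage sketch, not part of the commit: instantiating one of the presets
# defined above and reading back the per-layer values resolved in __post_init__.
#   cfg = OpenELMConfig(**OpenELM_CONFIGS["OpenELM-270M"])
#   len(cfg.num_query_heads) == cfg.num_transformer_layers == 16
#   cfg.num_kv_heads[i] == cfg.num_query_heads[i] // cfg.num_gqa_groups   # 4 GQA groups
#   cfg.ffn_multipliers                                                   # 16 values spaced from 0.5 to 4.0
# For layer 0 of this preset, make_divisible(1280 * 0.5, divisor=64 * 4) == 768,
# giving 12 query heads and 3 key/value heads.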
349
+ class OpenELMRMSNorm(nn.Module):
350
+ def __init__(self, num_features: int, eps: float = 1e-6):
351
+ """
352
+ Initialize the OpenELMRMSNorm normalization layer.
353
+ Args:
354
+ num_features (int): The feature dimension of the input tensor.
355
+ eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
356
+ Attributes:
357
+ eps (float): A small value added to the denominator for numerical stability.
358
+ weight (nn.Parameter): Learnable scaling parameter.
359
+ """
360
+ super().__init__()
361
+ self.eps = eps
362
+ self.weight = nn.Parameter(torch.ones(num_features))
363
+ self.num_features = num_features
364
+
365
+ def _norm(self, x: Tensor) -> Tensor:
366
+ """
367
+ Apply the OpenELMRMSNorm normalization to the input tensor.
368
+ Args:
369
+ x (torch.Tensor): The input tensor.
370
+ Returns:
371
+ torch.Tensor: The normalized tensor.
372
+ """
373
+ return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
374
+
375
+ def forward(self, x: Tensor) -> Tensor:
376
+ """
377
+ Forward pass through the OpenELMRMSNorm layer.
378
+ Args:
379
+ x (torch.Tensor): The input tensor.
380
+ Returns:
381
+ torch.Tensor: The output tensor after applying OpenELMRMSNorm.
382
+ """
383
+ output = self._norm(x.float()).type_as(x)
384
+ return output * self.weight
385
+
386
+ def extra_repr(self) -> str:
387
+ return (
388
+ super().extra_repr() + f"num_features={self.num_features}, eps={self.eps}"
389
+ )
390
+
391
+
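# Editorial sketch, not part of the commit: a quick numerical sanity check of the
# RMSNorm above. With all-ones input, mean(x**2) == 1, so the output is ~x * weight.
#   norm = OpenELMRMSNorm(num_features=8)
#   y = norm(torch.ones(2, 4, 8))   # y.shape == (2, 4, 8), values ~1.0 (weight is initialised to ones)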
392
+ class OpenELMPreTrainedModel(PreTrainedModel):
393
+ config_class = OpenELMConfig
394
+ base_model_prefix = "transformer"
395
+ supports_gradient_checkpointing = True
396
+ _no_split_modules = ["OpenELMDecoderLayer"]
397
+ _skip_keys_device_placement = "past_key_values"
398
+
399
+ def __init__(self, *inputs, **kwargs) -> None:
400
+ super().__init__(*inputs, **kwargs)
401
+
402
+ def _init_weights(self, module: nn.Module) -> None:
403
+ """Initialize the weights."""
404
+ if isinstance(module, nn.Linear):
405
+ # Slightly different from the TF version which uses truncated_normal for initialization
406
+ # cf https://github.com/pytorch/pytorch/pull/5617
407
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
408
+ if module.bias is not None:
409
+ module.bias.data.zero_()
410
+ elif isinstance(module, nn.Embedding):
411
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
412
+ if module.padding_idx is not None:
413
+ module.weight.data[module.padding_idx].zero_()
414
+ elif isinstance(module, OpenELMRMSNorm):
415
+ module.weight.data.fill_(1.0)
416
+
417
+
418
+ def _rotate_half(x: Tensor) -> Tensor:
419
+ x1, x2 = x.chunk(2, dim=-1)
420
+ return torch.cat((-x2, x1), dim=-1)
421
+
422
+
423
+ def _apply_rotary_pos_emb(x: Tensor, pos_sin: Tensor, pos_cos: Tensor) -> Tensor:
424
+ return (x * pos_cos) + (_rotate_half(x) * pos_sin)
425
+
426
+
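# Editorial sketch, not part of the commit: on the last dimension, _rotate_half maps
# [a, b, c, d] to [-c, -d, a, b]; _apply_rotary_pos_emb then mixes the original and
# rotated halves with the cached cos/sin tables, which is the standard RoPE rotation.
#   _rotate_half(torch.tensor([1., 2., 3., 4.]))   # -> tensor([-3., -4., 1., 2.])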
427
+ class OpenELMRotaryEmbedding(torch.nn.Module):
428
+ """
429
+ The rotary position embeddings (aka RoPE) from `RoFormer <https://arxiv.org/abs/2104.09864>`_.
430
+ RoPE encodes the position information of tokens using a rotation matrix, and is able to capture
431
+ explicit relative positional dependencies.
432
+ Args:
433
+ model_dim: The dimensionality of the model's hidden state.
434
+ max_seq_length: Maximum sequence length.
435
+ freq_constant: A constant used for computing frequencies.
436
+ """
437
+
438
+ def __init__(
439
+ self, model_dim: int, max_seq_length: int, freq_constant: int = 10000
440
+ ) -> None:
441
+ inv_freq = 1.0 / (
442
+ freq_constant
443
+ ** (torch.arange(0, model_dim, 2, dtype=torch.float32) / model_dim)
444
+ )
445
+ super().__init__()
446
+
447
+ self.model_dim = model_dim
448
+ self.freq_constant = freq_constant
449
+ self.max_seq_length = max_seq_length
450
+
451
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
452
+ self._cached_cos = None
453
+ self._cached_sin = None
454
+ self._cached_seq_length = max_seq_length
455
+ self._compute_sin_cos_embeddings(max_seq_length)
456
+
457
+ def extra_repr(self) -> str:
458
+ return f"\tmodel_dim={self.model_dim}, max_seq_length={self.max_seq_length}, freq_constant={self.freq_constant}"
459
+
460
+ def _compute_sin_cos_embeddings(
461
+ self,
462
+ key_len: int,
463
+ key_device: torch.device = torch.device("cpu"),
464
+ key_dtype: torch.dtype = torch.float32,
465
+ ) -> None:
466
+ """
467
+ Compute sine and cos embeddings.
468
+ Args:
469
+ key_len: Number of tokens in the key embeddings in the transformer model.
470
+ device: Device where the key embeddings are stored.
471
+ key_dtype: Data type of the key embeddings.
472
+ Returns:
473
+ None
474
+ ...note:
475
+ We recalculate the sine and cosine embeddings if any of the following conditions are met:
476
+ 1. The number of tokens in the key embeddings is greater than the cached sequence length.
477
+ 2. Sine and cosine caches are empty.
478
+ 3. The device and data type of the sine and cosine embeddings do not match those of the key embeddings.
479
+ """
480
+ if (
481
+ key_len > self._cached_seq_length
482
+ or self._cached_cos is None
483
+ or (self._cached_cos is not None and self._cached_cos.device != key_device)
484
+ or (self._cached_cos is not None and self._cached_cos.dtype != key_dtype)
485
+ or self._cached_sin is None
486
+ or (self._cached_sin is not None and self._cached_sin.device != key_device)
487
+ or (self._cached_sin is not None and self._cached_sin.dtype != key_dtype)
488
+ ):
489
+ self._cached_seq_length = max(key_len, self._cached_seq_length)
490
+
491
+ # The shape of 'pos_index' is [number of key tokens]
492
+ pos_index = torch.arange(
493
+ self._cached_seq_length,
494
+ dtype=torch.float32,
495
+ device=self.inv_freq.device,
496
+ )
497
+ # The shape of 'pos_index_theta' is [number of key tokens, model dimension]
498
+ pos_index_theta = torch.einsum("i,j->ij", pos_index, self.inv_freq)
499
+ # The shape of 'emb' is [number of key tokens, model dimension]
500
+ emb = torch.cat((pos_index_theta, pos_index_theta), dim=-1)
501
+
502
+ # the shape of cos and sin embeddings is [number of key tokens, model_dim]
503
+ cos_emb = emb.cos().to(dtype=key_dtype, device=key_device)
504
+ sin_emb = emb.sin().to(dtype=key_dtype, device=key_device)
505
+
506
+ # the shape of cached cos and sin embeddings is [1, 1, number of key tokens, model_dim]
507
+ self._cached_cos = cos_emb[None, None, :, :]
508
+ self._cached_sin = sin_emb[None, None, :, :]
509
+
510
+ def forward(
511
+ self,
512
+ query: torch.Tensor,
513
+ key: torch.Tensor,
514
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
515
+ """
516
+ The forward function of RoPE embeddings.
517
+ Args:
518
+ query: Query embeddings in the transformer model. The shape of query embeddings is
519
+ [Batch, number of query heads, number of query tokens, model dimension].
520
+ key: Key embeddings in the transformer model. The shape of key embeddings is
521
+ [Batch, number of key heads, number of key tokens, model dimension].
522
+ Returns:
523
+ A tuple containing the query and key embeddings with positional information. The shape of the returned query
524
+ and key embeddings is the same as the input query and key embeddings respectively.
525
+ ...note:
526
+ The RoPE embedding computation is done in full-precision. After the computation, input query and key tensors
527
+ are cast to the original input datatype.
528
+ """
529
+ dim = key.shape[-1]
530
+ key_len = key.shape[2]
531
+ query_len = query.shape[2]
532
+
533
+ assert dim == self.model_dim
534
+ assert key.device == query.device
535
+ assert key.dtype == query.dtype
536
+
537
+ # In the context of self-attention, the lengths of keys and queries are equal.
538
+ # However, in generation tasks, such as predicting the next token in a sequence, the lengths of keys and queries
539
+ # can differ. For instance, when employing key-value (KV) caching for sequence prediction, the keys
540
+ # represent embeddings of previous tokens and the current token, while the query corresponds
541
+ # to the embedding of the current token only.
542
+ assert (
543
+ key_len >= query_len
544
+ ), "Number of keys has to be greater than or equal to number of queries."
545
+
546
+ query_float = query.float()
547
+ key_float = key.float()
548
+
549
+ self._compute_sin_cos_embeddings(
550
+ key_len, key_device=key_float.device, key_dtype=key_float.dtype
551
+ )
552
+ query_float = _apply_rotary_pos_emb(
553
+ x=query_float,
554
+ pos_sin=self._cached_sin[..., key_len - query_len : key_len, :],
555
+ pos_cos=self._cached_cos[..., key_len - query_len : key_len, :],
556
+ )
557
+ key_float = _apply_rotary_pos_emb(
558
+ x=key_float,
559
+ pos_sin=self._cached_sin[..., :key_len, :],
560
+ pos_cos=self._cached_cos[..., :key_len, :],
561
+ )
562
+
563
+ return query_float.type_as(query), key_float.type_as(key)
564
+
565
+
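# Editorial sketch, not part of the commit: expected shapes when applying the rotary
# embedding defined above inside attention (head_dim plays the role of model_dim here).
#   rope = OpenELMRotaryEmbedding(model_dim=64, max_seq_length=4096)
#   q = torch.randn(1, 12, 1, 64)    # a single decoded token
#   k = torch.randn(1, 12, 9, 64)    # 8 cached tokens plus the current one
#   q_rot, k_rot = rope(q, k)        # shapes unchanged; q gets position 8, k gets positions 0..8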
566
+ class OpenELMMultiHeadCausalAttention(nn.Module):
567
+ def __init__(self, config: OpenELMConfig, layer_idx: int) -> None:
568
+ super().__init__()
569
+ self.layer_idx = layer_idx
570
+ head_dim = config.head_dim
571
+ q_heads = config.num_query_heads[layer_idx]
572
+ k_heads = config.num_kv_heads[layer_idx]
573
+ v_heads = config.num_kv_heads[layer_idx]
574
+
575
+ self.qkv_proj = nn.Linear(
576
+ in_features=config.model_dim,
577
+ out_features=(q_heads + k_heads + v_heads) * head_dim,
578
+ bias=False,
579
+ )
580
+
581
+ self.pos_embedding = OpenELMRotaryEmbedding(
582
+ model_dim=config.head_dim,
583
+ max_seq_length=config.rope_max_length,
584
+ freq_constant=config.rope_freq_constant,
585
+ )
586
+
587
+ if config.normalize_qk_projections:
588
+ self.q_norm = OpenELMRMSNorm(
589
+ num_features=config.head_dim,
590
+ )
591
+ self.k_norm = OpenELMRMSNorm(
592
+ num_features=config.head_dim,
593
+ )
594
+ else:
595
+ self.q_norm = None
596
+ self.k_norm = None
597
+
598
+ self.out_proj = nn.Linear(
599
+ in_features=q_heads * head_dim,
600
+ out_features=config.model_dim,
601
+ bias=False,
602
+ )
603
+
604
+ self.head_dim = config.head_dim
605
+ self.num_q_heads = q_heads
606
+ self.num_k_heads = k_heads
607
+ self.num_v_heads = v_heads
608
+ self.transformer_dim = config.model_dim
609
+ self.num_groups = self.num_q_heads // self.num_k_heads
610
+
611
+ def extra_repr(self) -> str:
612
+ return (
613
+ super().extra_repr()
614
+ + f"query_heads={self.num_q_heads}, key_heads={self.num_k_heads}, value_heads={self.num_v_heads}"
615
+ )
616
+
617
+ def forward(
618
+ self,
619
+ hidden_states: torch.Tensor,
620
+ attention_mask: Optional[torch.Tensor] = None,
621
+ past_key_value: Optional[Cache] = None,
622
+ output_attentions: bool = False,
623
+ use_cache: bool = False,
624
+ cache_position: Optional[torch.LongTensor] = None,
625
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
626
+ """
627
+ Forward pass of multi-head self-attention.
628
+ Args:
629
+ hidden_states: Input tensor of the shape [batch size, sequence length, model dimension].
630
+ past_key_value: Tensor storing the cached keys and values.
631
+ output_attentions: output attention weights.
632
+ use_cache: Specifies whether to use kv-cache for generation.
633
+ cache_position: used for updating the kv-cache.
634
+ Returns:
635
+ The output of the same shape as the input, optionally with a tensor containing cached keys and values.
636
+ """
637
+
638
+ # scaled_dot_product_attention does not return attention weights, set output_attentions to False
639
+ output_attentions = False
640
+ batch_size, seq_length, d_model = hidden_states.size()
641
+
642
+ # [B, S, d] --> [B, S, (q_h + k_h + v_h) * h]
643
+ qkv = self.qkv_proj(hidden_states)
644
+ # [B, S, (q_h + k_h + v_h) * h] --> [B, S, (q_h + k_h + v_h), h]
645
+ qkv = qkv.reshape(
646
+ batch_size,
647
+ seq_length,
648
+ self.num_q_heads + self.num_k_heads + self.num_v_heads,
649
+ self.head_dim,
650
+ )
651
+ # [B, S, (q_h + k_h + v_h), h] --> [B, (q_h + k_h + v_h), S, h]
652
+ qkv = qkv.transpose(1, 2)
653
+ # [B, (q_h + k_h + v_h), S, h] --> [B, q_h, S, h], [B, k_h, S, h], [B, v_h, S, h]
654
+ queries, keys, values = qkv.split(
655
+ [self.num_q_heads, self.num_k_heads, self.num_v_heads], dim=1
656
+ )
657
+
658
+ if self.q_norm is not None:
659
+ queries = self.q_norm(queries)
660
+
661
+ if self.k_norm is not None:
662
+ keys = self.k_norm(keys)
663
+
664
+ past_key_value = getattr(self, "past_key_value", past_key_value)
665
+
666
+ if past_key_value is not None:
667
+ # sin and cos are specific to RoPE models; position_ids needed for the static cache
668
+ # cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
669
+ cache_kwargs = {"cache_position": cache_position}
670
+ keys, values = past_key_value.update(
671
+ keys, values, self.layer_idx, cache_kwargs
672
+ )
673
+
674
+ # Add positional embedding
675
+ queries, keys = self.pos_embedding(queries, keys)
676
+
677
+ if self.num_groups != 1:
678
+ # GQA
679
+ # [B, k_h, S, h] --> [B, q_h, S, h]
680
+ keys = keys.repeat_interleave(self.num_groups, dim=1)
681
+ # [B, v_h, S, h] --> [B, q_h, S, h]
682
+ values = values.repeat_interleave(self.num_groups, dim=1)
683
+
684
+ causal_mask = attention_mask
685
+ if attention_mask is not None and cache_position is not None:
686
+ causal_mask = causal_mask[:, :, cache_position, : keys.shape[-2]]
687
+
688
+ attn_output = F.scaled_dot_product_attention(
689
+ queries,
690
+ keys,
691
+ values,
692
+ attn_mask=causal_mask,
693
+ dropout_p=0,
694
+ )
695
+
696
+ attn_output = attn_output.transpose(1, 2).contiguous()
697
+ attn_output = attn_output.reshape(
698
+ batch_size, seq_length, self.num_q_heads * self.head_dim
699
+ )
700
+ attn_output = self.out_proj(attn_output)
701
+ if not output_attentions:
702
+ attn_weights = None
703
+ return attn_output, attn_weights, past_key_value
704
+
705
+
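# Editorial note, not part of the commit: the attention block above calls
# F.scaled_dot_product_attention, so torch.nn.functional is assumed to be imported
# as F in the file header (outside this hunk). For GQA, keys/values are expanded to
# match the query heads before SDPA, e.g. with 3 key/value heads and 12 query heads:
#   keys = keys.repeat_interleave(4, dim=1)   # [B, 3, S, h] -> [B, 12, S, h]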
706
+ class OpenELMFeedForwardNetwork(nn.Module):
707
+ def __init__(self, config: OpenELMConfig, layer_idx: int) -> None:
708
+ super().__init__()
709
+ ffn_multiplier = config.ffn_multipliers[layer_idx]
710
+ intermediate_dim = int(
711
+ make_divisible(
712
+ ffn_multiplier * config.model_dim,
713
+ divisor=config.ffn_dim_divisor,
714
+ )
715
+ )
716
+ if config.ffn_with_glu:
717
+ # FFN with Gated linear unit, as described in https://arxiv.org/abs/2002.05202v1.
718
+ self.proj_1 = nn.Linear(
719
+ in_features=config.model_dim,
720
+ out_features=2 * intermediate_dim,
721
+ bias=False,
722
+ )
723
+ self.proj_2 = nn.Linear(
724
+ in_features=intermediate_dim,
725
+ out_features=config.model_dim,
726
+ bias=False,
727
+ )
728
+ self.ffn_with_glu = True
729
+ else:
730
+ # Standard FFN, as described in https://arxiv.org/abs/1706.03762
731
+ self.proj_1 = nn.Linear(
732
+ in_features=config.model_dim,
733
+ out_features=intermediate_dim,
734
+ bias=False,
735
+ )
736
+ self.proj_2 = nn.Linear(
737
+ in_features=intermediate_dim,
738
+ out_features=config.model_dim,
739
+ bias=False,
740
+ )
741
+ self.ffn_with_glu = False
742
+
743
+ self.act = ACT2FN[config.activation_fn_name]
744
+
745
+ def extra_repr(self) -> str:
746
+ return super().extra_repr() + f"(ffn_with_glu) : {self.ffn_with_glu}"
747
+
748
+ def forward(self, x: Tensor) -> Tensor:
749
+ """Forward function of FFN layer.
750
+ Args:
751
+ x: Input tensor of the shape [batch size, sequence length, model dimension].
752
+ Returns:
753
+ A tensor of the same shape as the input.
754
+ """
755
+ if self.ffn_with_glu:
756
+ y_12 = self.proj_1(x)
757
+ y_1, y_2 = y_12.chunk(2, dim=-1)
758
+ y = self.act(y_1) * y_2
759
+ return self.proj_2(y)
760
+ else:
761
+ return self.proj_2(self.act(self.proj_1(x)))
762
+
763
+
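# Editorial sketch, not part of the commit: with ffn_with_glu=True the block above
# computes proj_2(swish(y_1) * y_2), where y_1 and y_2 are the two halves of proj_1's
# output (SwiGLU, https://arxiv.org/abs/2002.05202). Example sizing for layer 0 of the
# OpenELM-270M preset: make_divisible(0.5 * 1280, divisor=256) == 768, so proj_1 maps
# 1280 -> 2 * 768 and proj_2 maps 768 -> 1280.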
764
+ class OpenELMDecoderLayer(nn.Module):
765
+ def __init__(self, config: OpenELMConfig, layer_idx: int) -> None:
766
+ super().__init__()
767
+ self.attn = OpenELMMultiHeadCausalAttention(config=config, layer_idx=layer_idx)
768
+ self.ffn = OpenELMFeedForwardNetwork(config=config, layer_idx=layer_idx)
769
+ self.ffn_norm = OpenELMRMSNorm(
770
+ num_features=config.model_dim,
771
+ )
772
+ self.attn_norm = OpenELMRMSNorm(
773
+ num_features=config.model_dim,
774
+ )
775
+
776
+ def forward(
777
+ self,
778
+ hidden_states: torch.Tensor,
779
+ attention_mask: Optional[torch.Tensor] = None,
780
+ position_ids: Optional[torch.LongTensor] = None,
781
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
782
+ output_attentions: Optional[bool] = False,
783
+ use_cache: Optional[bool] = False,
784
+ cache_position: Optional[torch.LongTensor] = None,
785
+ **kwargs,
786
+ ) -> Tuple[
787
+ torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
788
+ ]:
789
+ """
790
+ Args:
791
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
792
+ attention_mask (`torch.FloatTensor`, *optional*):
793
+ attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
794
+ query_sequence_length, key_sequence_length)` if default attention is used.
795
+ output_attentions (`bool`, *optional*):
796
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
797
+ returned tensors for more detail.
798
+ use_cache (`bool`, *optional*):
799
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
800
+ (see `past_key_values`).
801
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
802
+ """
803
+ residual = hidden_states
804
+ hidden_states = self.attn_norm(hidden_states)
805
+
806
+ # Self Attention
807
+ hidden_states, self_attn_weights, present_key_value = self.attn(
808
+ hidden_states=hidden_states,
809
+ attention_mask=attention_mask,
810
+ past_key_value=past_key_value,
811
+ output_attentions=output_attentions,
812
+ use_cache=use_cache,
813
+ cache_position=cache_position,
814
+ **kwargs,
815
+ )
816
+ hidden_states = residual + hidden_states
817
+
818
+ # Fully Connected
819
+ residual = hidden_states
820
+ hidden_states = self.ffn_norm(hidden_states)
821
+ hidden_states = self.ffn(hidden_states)
822
+ hidden_states = residual + hidden_states
823
+
824
+ outputs = (hidden_states,)
825
+
826
+ if output_attentions:
827
+ outputs += (self_attn_weights,)
828
+
829
+ if use_cache:
830
+ outputs += (present_key_value,)
831
+
832
+ return outputs
833
+
834
+
835
+ class OpenELMModel(OpenELMPreTrainedModel):
836
+ config_class = OpenELMConfig
837
+
838
+ def __init__(self, config: OpenELMConfig):
839
+ super().__init__(config)
840
+ self.config = config
841
+
842
+ self.token_embeddings = nn.Embedding(
843
+ embedding_dim=config.model_dim,
844
+ num_embeddings=config.vocab_size,
845
+ )
846
+
847
+ self.layers = nn.ModuleList(
848
+ OpenELMDecoderLayer(config=config, layer_idx=layer_idx)
849
+ for layer_idx in range(config.num_transformer_layers)
850
+ )
851
+ self.norm = OpenELMRMSNorm(num_features=config.model_dim)
852
+ if config.share_input_output_layers:
853
+ self.classifier = None
854
+ else:
855
+ self.classifier = nn.Linear(
856
+ in_features=config.model_dim,
857
+ out_features=config.vocab_size,
858
+ bias=False,
859
+ )
860
+ self.num_transformer_layers = config.num_transformer_layers
861
+ self.gradient_checkpointing = False
862
+
863
+ # Register a causal mask to separate causal and padding mask creation. Merging happens in the attention class.
864
+ # NOTE: This is not friendly with TorchScript, ONNX, ExportedProgram serialization for very large `max_context_length`.
865
+ causal_mask = torch.full(
866
+ (config.max_context_length, config.max_context_length),
867
+ fill_value=True,
868
+ dtype=torch.bool,
869
+ )
870
+ self.register_buffer(
871
+ "causal_mask", torch.triu(causal_mask, diagonal=1), persistent=False
872
+ )
873
+
874
+ # Initialize weights and apply final processing
875
+ self.post_init()
876
+ self.reset_parameters(config=config)
877
+
878
+ def get_input_embeddings(self):
879
+ return self.token_embeddings
880
+
881
+ def set_input_embeddings(self, new_embeddings: torch.Tensor):
882
+ self.token_embeddings = new_embeddings
883
+
884
+ def reset_parameters(self, config: OpenELMConfig) -> None:
885
+ """Initialize the layers in Language Model
886
+ The initialization scheme is followed, following `OPT <https://arxiv.org/pdf/2205.01068.pdf>`_.
887
+ Args:
888
+ use_megatron_std: Use standard deviation as described in Megatron-LM.
889
+ Returns:
890
+ None
891
+ """
892
+ for module in self.modules():
893
+ if isinstance(module, nn.Linear):
894
+ std = module.in_features**-0.5
895
+ torch.nn.init.normal_(module.weight, mean=0.0, std=std)
896
+ if module.bias is not None:
897
+ torch.nn.init.zeros_(module.bias)
898
+ elif isinstance(module, nn.Embedding):
899
+ std = module.embedding_dim**-0.5
900
+ torch.nn.init.normal_(module.weight, mean=0.0, std=std)
901
+ elif isinstance(module, OpenELMRMSNorm):
902
+ if module.weight is not None:
903
+ torch.nn.init.ones_(module.weight)
904
+ if hasattr(module, "bias") and module.bias is not None:
905
+ torch.nn.init.zeros_(module.bias)
906
+
907
+ model_dim = config.model_dim
908
+ n_layers = config.num_transformer_layers
909
+ std = (model_dim**-0.5) * ((2 * n_layers) ** -0.5)
910
+ for param_name, param in self.named_parameters():
911
+ if param_name.endswith("out_proj.weight") or param_name.endswith(
912
+ "ffn.proj_2.weight"
913
+ ):
914
+ torch.nn.init.normal_(param, mean=0.0, std=std)
915
+
916
+ def forward(
917
+ self,
918
+ input_ids: torch.LongTensor = None,
919
+ attention_mask: Optional[torch.Tensor] = None,
920
+ position_ids: Optional[torch.LongTensor] = None,
921
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
922
+ inputs_embeds: Optional[torch.FloatTensor] = None,
923
+ use_cache: Optional[bool] = None,
924
+ output_attentions: Optional[bool] = None,
925
+ output_hidden_states: Optional[bool] = None,
926
+ return_dict: Optional[bool] = None,
927
+ cache_position: Optional[torch.LongTensor] = None,
928
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
929
+ output_attentions = (
930
+ output_attentions
931
+ if output_attentions is not None
932
+ else self.config.output_attentions
933
+ )
934
+ output_hidden_states = (
935
+ output_hidden_states
936
+ if output_hidden_states is not None
937
+ else self.config.output_hidden_states
938
+ )
939
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
940
+ return_dict = (
941
+ return_dict if return_dict is not None else self.config.use_return_dict
942
+ )
943
+
944
+ if (input_ids is None) ^ (inputs_embeds is not None):
945
+ raise ValueError(
946
+ "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
947
+ )
948
+
949
+ if self.gradient_checkpointing and self.training and use_cache:
950
+ logger.warning_once(
951
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
952
+ )
953
+ use_cache = False
954
+
955
+ if inputs_embeds is None:
956
+ inputs_embeds = self.token_embeddings(input_ids)
957
+
958
+ past_seen_tokens = 0
959
+ if use_cache: # kept for BC (cache positions)
960
+ if not isinstance(past_key_values, StaticCache):
961
+ past_key_values = DynamicCache.from_legacy_cache(past_key_values)
962
+ past_seen_tokens = past_key_values.get_seq_length()
963
+
964
+ if cache_position is None:
965
+ cache_position = torch.arange(
966
+ past_seen_tokens,
967
+ past_seen_tokens + inputs_embeds.shape[1],
968
+ device=inputs_embeds.device,
969
+ )
970
+
971
+ if position_ids is None:
972
+ position_ids = cache_position.unsqueeze(0)
973
+
974
+ causal_mask = self._update_causal_mask(attention_mask, inputs_embeds)
975
+
976
+ # embed positions
977
+ hidden_states = inputs_embeds
978
+
979
+ # decoder layers
980
+ all_hidden_states = () if output_hidden_states else None
981
+ all_self_attns = () if output_attentions else None
982
+ next_decoder_cache = None
983
+
984
+ for decoder_layer in self.layers:
985
+ if output_hidden_states:
986
+ all_hidden_states += (hidden_states,)
987
+
988
+ if self.gradient_checkpointing and self.training:
989
+ layer_outputs = self._gradient_checkpointing_func(
990
+ decoder_layer.__call__,
991
+ hidden_states,
992
+ causal_mask,
993
+ position_ids,
994
+ past_key_values,
995
+ output_attentions,
996
+ use_cache,
997
+ cache_position,
998
+ )
999
+ else:
1000
+ layer_outputs = decoder_layer(
1001
+ hidden_states,
1002
+ attention_mask=causal_mask,
1003
+ position_ids=position_ids,
1004
+ past_key_value=past_key_values,
1005
+ output_attentions=output_attentions,
1006
+ use_cache=use_cache,
1007
+ cache_position=cache_position,
1008
+ )
1009
+
1010
+ hidden_states = layer_outputs[0]
1011
+
1012
+ if use_cache:
1013
+ next_decoder_cache = layer_outputs[2 if output_attentions else 1]
1014
+
1015
+ if output_attentions:
1016
+ all_self_attns += (layer_outputs[1],)
1017
+
1018
+ hidden_states = self.norm(hidden_states)
1019
+
1020
+ # add hidden states from the last decoder layer
1021
+ if output_hidden_states:
1022
+ all_hidden_states += (hidden_states,)
1023
+
1024
+ next_cache = None
1025
+ if use_cache:
1026
+ next_cache = (
1027
+ next_decoder_cache.to_legacy_cache()
1028
+ if isinstance(next_decoder_cache, Cache)
1029
+ else next_decoder_cache
1030
+ )
1031
+ if not return_dict:
1032
+ return tuple(
1033
+ v
1034
+ for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
1035
+ if v is not None
1036
+ )
1037
+ return BaseModelOutputWithPast(
1038
+ last_hidden_state=hidden_states,
1039
+ past_key_values=next_cache,
1040
+ hidden_states=all_hidden_states,
1041
+ attentions=all_self_attns,
1042
+ )
1043
+
1044
+ def _update_causal_mask(self, attention_mask, input_tensor):
1045
+ if self.config._attn_implementation == "flash_attention_2":
1046
+ if attention_mask is not None and 0.0 in attention_mask:
1047
+ return attention_mask
1048
+ return None
1049
+
1050
+ batch_size, seq_length = input_tensor.shape[:2]
1051
+ dtype = input_tensor.dtype
1052
+ device = input_tensor.device
1053
+
1054
+ # support going beyond cached `max_position_embedding`
1055
+ if seq_length > self.causal_mask.shape[-1]:
1056
+ causal_mask = torch.full(
1057
+ (2 * self.causal_mask.shape[-1], 2 * self.causal_mask.shape[-1]),
1058
+ fill_value=1,
1059
+ )
1060
+ self.register_buffer(
1061
+ "causal_mask", torch.triu(causal_mask, diagonal=1), persistent=False
1062
+ )
1063
+
1064
+ # We use the current dtype to avoid any overflows
1065
+ min_dtype = torch.finfo(dtype).min
1066
+ causal_mask = (
1067
+ self.causal_mask[None, None, :, :].repeat(batch_size, 1, 1, 1).to(dtype)
1068
+ * min_dtype
1069
+ )
1070
+
1071
+ causal_mask = causal_mask.to(dtype=dtype, device=device)
1072
+ if attention_mask is not None and attention_mask.dim() == 2:
1073
+ mask_length = attention_mask.shape[-1]
1074
+ padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[
1075
+ :, None, None, :
1076
+ ].eq(0.0)
1077
+ causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(
1078
+ padding_mask, min_dtype
1079
+ )
1080
+
1081
+ if self.config._attn_implementation == "sdpa" and attention_mask is not None:
1082
+ # For dynamo, rather use a check on fullgraph=True once this is possible (https://github.com/pytorch/pytorch/pull/120400).
1083
+ is_tracing = (
1084
+ torch.jit.is_tracing()
1085
+ or isinstance(input_tensor, torch.fx.Proxy)
1086
+ or (hasattr(torch, "_dynamo") and torch._dynamo.is_compiling())
1087
+ )
1088
+ if not is_tracing and torch.any(attention_mask != 1):
1089
+ # Attend to all tokens in masked rows from the causal_mask, for example the relevant first rows when
1090
+ # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
1091
+ # Details: https://github.com/pytorch/pytorch/issues/110213
1092
+ causal_mask = causal_mask.mul(
1093
+ ~torch.all(causal_mask == min_dtype, dim=-1, keepdim=True)
1094
+ ).to(dtype)
1095
+
1096
+ return causal_mask
1097
+
1098
+
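# Editorial note, not part of the commit: _update_causal_mask above returns an additive
# float mask built from the registered causal_mask buffer (0 where attention is allowed,
# torch.finfo(dtype).min where it is blocked, with 2-D padding masks folded in); the
# attention layer then slices it to [batch, 1, query_len, key_len] via cache_position
# before passing it to F.scaled_dot_product_attention as attn_mask.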
1099
+ class OpenELMForCausalLM(OpenELMPreTrainedModel):
1100
+ _tied_weights_keys = ["lm_head.weight"]
1101
+
1102
+ def __init__(self, config: OpenELMConfig):
1103
+ super().__init__(config)
1104
+ self.transformer = OpenELMModel(config)
1105
+ self.vocab_size = config.vocab_size
1106
+ if config.share_input_output_layers:
1107
+ self.lm_head = None
1108
+ else:
1109
+ self.lm_head = nn.Linear(config.model_dim, config.vocab_size, bias=False)
1110
+
1111
+ # Initialize weights and apply final processing
1112
+ self.post_init()
1113
+
1114
+ def get_input_embeddings(self):
1115
+ return self.transformer.token_embeddings
1116
+
1117
+ def set_input_embeddings(self, value):
1118
+ self.transformer.token_embeddings = value
1119
+
1120
+ def get_output_embeddings(self):
1121
+ return self.lm_head
1122
+
1123
+ def set_output_embeddings(self, new_embeddings):
1124
+ self.lm_head = new_embeddings
1125
+
1126
+ def set_decoder(self, decoder):
1127
+ self.transformer = decoder
1128
+
1129
+ def get_decoder(self):
1130
+ return self.transformer
1131
+
1132
+ def forward(
1133
+ self,
1134
+ input_ids: torch.LongTensor = None,
1135
+ attention_mask: Optional[torch.Tensor] = None,
1136
+ position_ids: Optional[torch.LongTensor] = None,
1137
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
1138
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1139
+ labels: Optional[torch.LongTensor] = None,
1140
+ use_cache: Optional[bool] = None,
1141
+ output_attentions: Optional[bool] = None,
1142
+ output_hidden_states: Optional[bool] = None,
1143
+ return_dict: Optional[bool] = None,
1144
+ cache_position: Optional[torch.LongTensor] = None,
1145
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
1146
+ output_attentions = (
1147
+ output_attentions
1148
+ if output_attentions is not None
1149
+ else self.config.output_attentions
1150
+ )
1151
+ output_hidden_states = (
1152
+ output_hidden_states
1153
+ if output_hidden_states is not None
1154
+ else self.config.output_hidden_states
1155
+ )
1156
+ return_dict = (
1157
+ return_dict if return_dict is not None else self.config.use_return_dict
1158
+ )
1159
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
1160
+ outputs = self.transformer(
1161
+ input_ids=input_ids,
1162
+ attention_mask=attention_mask,
1163
+ position_ids=position_ids,
1164
+ past_key_values=past_key_values,
1165
+ inputs_embeds=inputs_embeds,
1166
+ use_cache=use_cache,
1167
+ output_attentions=output_attentions,
1168
+ output_hidden_states=output_hidden_states,
1169
+ return_dict=return_dict,
1170
+ cache_position=cache_position,
1171
+ )
1172
+
1173
+ hidden_states = outputs[0]
1174
+ if self.lm_head is None:
1175
+ # shared
1176
+ logits = F.linear(
1177
+ hidden_states, weight=self.transformer.token_embeddings.weight
1178
+ )
1179
+ else:
1180
+ logits = self.lm_head(hidden_states)
1181
+ logits = logits[:, : self.config.vocab_size]
1182
+ loss = None
1183
+ if labels is not None:
1184
+ # Shift so that tokens < n predict n
1185
+ shift_logits = logits[..., :-1, :].contiguous()
1186
+ shift_labels = labels[..., 1:].contiguous()
1187
+ # Flatten the tokens
1188
+ loss_fct = CrossEntropyLoss()
1189
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
1190
+ shift_labels = shift_labels.view(-1)
1191
+ # Enable model parallelism
1192
+ shift_labels = shift_labels.to(shift_logits.device)
1193
+ loss = loss_fct(shift_logits, shift_labels)
1194
+
1195
+ if not return_dict:
1196
+ output = (logits,) + outputs[1:]
1197
+ return (loss,) + output if loss is not None else output
1198
+
1199
+ return CausalLMOutputWithPast(
1200
+ loss=loss,
1201
+ logits=logits,
1202
+ past_key_values=outputs.past_key_values,
1203
+ hidden_states=outputs.hidden_states,
1204
+ attentions=outputs.attentions,
1205
+ )
1206
+
1207
+ def prepare_inputs_for_generation(
1208
+ self,
1209
+ input_ids,
1210
+ past_key_values=None,
1211
+ attention_mask=None,
1212
+ inputs_embeds=None,
1213
+ **kwargs,
1214
+ ):
1215
+ past_length = 0
1216
+ if past_key_values is not None:
1217
+ if isinstance(past_key_values, Cache):
1218
+ cache_length = past_key_values.get_seq_length()
1219
+ past_length = past_key_values.seen_tokens
1220
+ max_cache_length = past_key_values.get_max_length()
1221
+ else:
1222
+ cache_length = past_length = past_key_values[0][0].shape[2]
1223
+ max_cache_length = None
1224
+
1225
+ # Keep only the unprocessed tokens:
1226
+ # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
1227
+ # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
1228
+ # input)
1229
+ if (
1230
+ attention_mask is not None
1231
+ and attention_mask.shape[1] > input_ids.shape[1]
1232
+ ):
1233
+ input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
1234
+ # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
1235
+ # input_ids based on the past_length.
1236
+ elif past_length < input_ids.shape[1]:
1237
+ input_ids = input_ids[:, past_length:]
1238
+ # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
1239
+
1240
+ # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
1241
+ if (
1242
+ max_cache_length is not None
1243
+ and attention_mask is not None
1244
+ and cache_length + input_ids.shape[1] > max_cache_length
1245
+ ):
1246
+ attention_mask = attention_mask[:, -max_cache_length:]
1247
+
1248
+ position_ids = kwargs.get("position_ids", None)
1249
+ if attention_mask is not None and position_ids is None:
1250
+ # create position_ids on the fly for batch generation
1251
+ position_ids = attention_mask.long().cumsum(-1) - 1
1252
+ position_ids.masked_fill_(attention_mask == 0, 1)
1253
+ if past_key_values:
1254
+ position_ids = position_ids[:, -input_ids.shape[1] :]
1255
+
1256
+ if self.generation_config.cache_implementation == "static":
1257
+ # generation with static cache
1258
+ cache_position = kwargs.get("cache_position", None)
1259
+ if cache_position is None:
1260
+ past_length = 0
1261
+ else:
1262
+ past_length = cache_position[-1] + 1
1263
+ input_ids = input_ids[:, past_length:]
1264
+ position_ids = position_ids[:, past_length:]
1265
+
1266
+ # we should only keep a `cache_position` in generate, and do +=1.
1267
+ # same goes for position ids. Could also help with continued generation.
1268
+ cache_position = torch.arange(
1269
+ past_length,
1270
+ past_length + position_ids.shape[-1],
1271
+ device=position_ids.device,
1272
+ )
1273
+
1274
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
1275
+ if inputs_embeds is not None and past_key_values is None:
1276
+ model_inputs = {"inputs_embeds": inputs_embeds}
1277
+ else:
1278
+ # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
1279
+ # recompiles graphs as the stride of the inputs is a guard. Ref: https://github.com/huggingface/transformers/pull/29114
1280
+ # We could use `next_tokens` directly instead.
1281
+ model_inputs = {"input_ids": input_ids.contiguous()}
1282
+
1283
+ model_inputs.update(
1284
+ {
1285
+ "position_ids": position_ids.contiguous(),
1286
+ "cache_position": cache_position,
1287
+ "past_key_values": past_key_values,
1288
+ "use_cache": kwargs.get("use_cache"),
1289
+ "attention_mask": attention_mask,
1290
+ }
1291
+ )
1292
+ return model_inputs
1293
+
1294
+ @staticmethod
1295
+ def _reorder_cache(past_key_values, beam_idx):
1296
+ reordered_past = ()
1297
+ for layer_past in past_key_values:
1298
+ reordered_past += (
1299
+ tuple(
1300
+ past_state.index_select(0, beam_idx.to(past_state.device))
1301
+ for past_state in layer_past
1302
+ ),
1303
+ )
1304
+ return reordered_past
1305
+
1306
 
1307
  ACT_TYPE = {
1308
  'relu': nn.ReLU,
 
1690
 
1691
 
1692
 
1693
+ AutoConfig.register("tinyllava", TinyLlavaConfig)
1694
+ AutoModelForCausalLM.register(TinyLlavaConfig, TinyLlavaForConditionalGeneration)
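# Editorial usage sketch, not part of the commit: with the two registrations above,
# a checkpoint repo that ships this file can be loaded through the Auto classes
# (the repo id below is a placeholder; trust_remote_code=True makes transformers
# import this module from the Hub):
#
#   from transformers import AutoModelForCausalLM, AutoTokenizer
#   model = AutoModelForCausalLM.from_pretrained("<repo_id>", trust_remote_code=True)
#   tokenizer = AutoTokenizer.from_pretrained("<repo_id>")
#
# Generation then goes through the TinyLlavaForConditionalGeneration wrapper defined
# in the unchanged part of this file.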