jiajunlong committed
Commit 1c7edc5
1 Parent(s): 3a0a50a

Delete modeling_elm.py

Files changed (1)
  1. modeling_elm.py +0 -1288
modeling_elm.py DELETED
@@ -1,1288 +0,0 @@
1
- #
2
- # For licensing see accompanying LICENSE file.
3
- # Copyright (C) 2024 Apple Inc. All Rights Reserved.
4
- #
5
-
6
- from typing import List, Optional, Tuple, Union
7
-
8
- import torch
9
- import torch.utils.checkpoint
10
- from torch import Tensor, nn
11
- from torch.nn import CrossEntropyLoss
12
- from torch.nn import functional as F
13
- from transformers import PreTrainedModel
14
- from transformers.activations import ACT2FN
15
- from transformers.cache_utils import Cache, DynamicCache, StaticCache
16
- from transformers.modeling_outputs import (
17
- BaseModelOutputWithPast,
18
- CausalLMOutputWithPast,
19
- )
20
- from transformers.utils import logging
21
-
22
- logger = logging.get_logger(__name__)
23
-
24
- # This import has to be relative; otherwise, when setting trust_remote_code=True,
25
- # huggingface transformers won't be able to load the module correctly.
26
- from numbers import Number
27
- from typing import List, Optional, Union
28
-
29
- import numpy as np
30
- from transformers import PretrainedConfig, AutoTokenizer
31
-
32
-
33
- def make_divisible(
34
- v: Union[float, int],
35
- divisor: Optional[int] = 8,
36
- min_value: Optional[Union[float, int]] = None,
37
- ) -> Union[float, int]:
38
- """
39
- This function is taken from the original tf repo.
40
- It ensures that all layers have a channel number that is divisible by the divisor
41
- It can be seen at:
42
- https://github.com/tensorflow/models/blob/2cfc99eff5e5eb729c6793d2f3d03aa1c9be2b15/research/slim/nets/mobilenet/mobilenet.py#L62
43
- Args:
44
- v: input value
45
- divisor: default to 8
46
- min_value: minimum allowed value for the result (defaults to the divisor)
47
- Returns:
48
- new_v: new divisible value
49
- """
50
- if min_value is None:
51
- min_value = divisor
52
- new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
53
- # Make sure that round down does not go down by more than 10%.
54
- if new_v < 0.9 * v:
55
- new_v += divisor
56
- return new_v
57
-
58
-
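As a quick illustration of the helper above (a minimal sketch, assuming the make_divisible definition from this file is in scope; the values are only examples):

# Round to a multiple of `divisor`, never dropping more than 10% below the input;
# `min_value` (default: the divisor) acts as a floor on the result.
assert make_divisible(100, divisor=8) == 104      # int(100 + 4) // 8 * 8
assert make_divisible(1000, divisor=256) == 1024  # int(1000 + 128) // 256 * 256
assert make_divisible(200, divisor=256) == 256    # floored at min_value == divisor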
59
- def compute_heads(model_dim: int, head_dim: int) -> int:
60
- """Compute the number of heads.
61
- Args:
62
- model_dim: Model dimension.
63
- head_dim: Head dimension.
64
- Returns:
65
- An integer denoting the number of heads in multi-head attention is returned.
66
- Raises:
67
- ValueError: if model dimension is not divisible by head dimension.
68
- """
69
- if model_dim % head_dim == 0:
70
- return model_dim // head_dim
71
- else:
72
- raise ValueError(
73
- f"Model dimension should be divisible by head dimension. Got: {model_dim} and {head_dim}."
74
- )
75
-
76
-
77
- OpenELM_CONFIGS = {
78
- "OpenELM-270M": dict(
79
- num_transformer_layers=16,
80
- model_dim=1280,
81
- head_dim=64,
82
- num_gqa_groups=4,
83
- normalize_qk_projections=True,
84
- share_input_output_layers=True,
85
- # Vary the FFN and QKV multipliers to create variable FFN and attention layers respectively.
86
- ffn_multipliers=(0.5, 4.0),
87
- qkv_multipliers=(0.5, 1.0),
88
- ),
89
- "OpenELM-450M": dict(
90
- num_transformer_layers=20,
91
- model_dim=1536,
92
- head_dim=64,
93
- num_gqa_groups=4,
94
- normalize_qk_projections=True,
95
- share_input_output_layers=True,
96
- # Vary the FFN and QKV multipliers to create variable FFN and attention layers respectively.
97
- ffn_multipliers=(0.5, 4.0),
98
- qkv_multipliers=(0.5, 1.0),
99
- ),
100
- "OpenELM-1_1B": dict(
101
- num_transformer_layers=28,
102
- model_dim=2048,
103
- head_dim=64,
104
- num_gqa_groups=4,
105
- normalize_qk_projections=True,
106
- share_input_output_layers=True,
107
- # Vary the FFN and QKV multipliers to create variable FFN and attention layers respectively.
108
- ffn_multipliers=(0.5, 4.0),
109
- qkv_multipliers=(0.5, 1.0),
110
- ),
111
- "OpenELM-3B": dict(
112
- num_transformer_layers=36,
113
- model_dim=3072,
114
- head_dim=128,
115
- num_gqa_groups=4,
116
- normalize_qk_projections=True,
117
- share_input_output_layers=True,
118
- # Vary the FFN and QKV multipliers to create variable FFN and attention layers respectively.
119
- ffn_multipliers=(0.5, 4.0),
120
- qkv_multipliers=(0.5, 1.0),
121
- ),
122
- }
123
-
124
-
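The two-element ffn_multipliers and qkv_multipliers above drive the block-wise scaling that __post_init__ (below) implements. A small sketch of the idea, using NumPy and the make_divisible helper defined earlier with the OpenELM-270M numbers (illustrative only):

import numpy as np

num_layers, model_dim, head_dim, num_gqa_groups = 16, 1280, 64, 4
multipliers = [round(v, 2) for v in np.linspace(0.5, 1.0, num=num_layers, dtype=float)]
# Snap each scaled width to a multiple of head_dim * num_gqa_groups.
query_dims = [int(make_divisible(model_dim * m, divisor=head_dim * num_gqa_groups))
              for m in multipliers]
num_query_heads = [d // head_dim for d in query_dims]          # grows with depth, 12 ... 20
num_kv_heads = [h // num_gqa_groups for h in num_query_heads]  # grows with depth, 3 ... 5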
125
- class OpenELMConfig(PretrainedConfig):
126
- r"""
127
- This is the configuration class to store the configuration of a [`OpenELMModel`]. It is used to instantiate an OpenELM model according to the specified arguments, defining the model architecture.
128
- Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
129
- documentation from [`PretrainedConfig`] for more information.
130
- Args:
131
- vocab_size (`int`, *optional*, defaults to 32000):
132
- Vocabulary size of the OpenELM model.
133
- max_context_length (`int`, *optional*, defaults to 2048):
134
- Maximum number of input tokens.
135
- num_transformer_layers (`int`, *optional*, defaults to 12):
136
- Number of hidden layers in the Transformer decoder.
137
- model_dim (`int`, *optional*, defaults to 2048):
138
- Dimension of the hidden representations.
139
- head_dim (`int`, *optional*, defaults to 128):
140
- The attention head dimension.
141
- qkv_multipliers (`Union[Number, List[Number]]`, *optional*, defaults to 1.0):
142
- If the qkv_multipliers is a Number, then all attention layers have the same latent dimensions,
143
- resulting in uniform allocation of parameters.
144
- If the qkv_multipliers is a List of Number, then each attention layer has different latent dimensions,
145
- assuming qkv_multipliers[0] != qkv_multipliers[1]. This results in variable allocation of parameters in the attention layers.
146
- This scaling is known as layer-wise or block-wise scaling: https://arxiv.org/abs/2008.00623
147
- num_query_heads (`Union[int, None]`, *optional*, defaults to None):
148
- The number of query heads, computed from `compute_heads(model_dim=model_dim, head_dim=head_dim)`.
149
- num_gqa_groups (`int`, *optional*, defaults to 1):
150
- This variable allows switching between multi-head attention, group query attention, and multi-query attention.
151
- When num_gqa_groups == 1, then it is multi-head attention.
152
- When 1 < num_gqa_groups < num_heads and num_heads is divisible by num_gqa_groups, then it is group query attention.
153
- When num_gqa_groups == num_heads, then it is multi-query attention.
154
- ffn_multipliers (`Union[Number, List[Number]]`, *optional*, defaults to 4.0):
155
- Feed-forward network (FFN) multipliers.
156
- If the ffn_multipliers is a Number, then all FFN layers have the same latent dimensions,
157
- resulting in uniform allocation of parameters.
158
- If the ffn_multipliers is a List of Number, then each FFN layer has different latent dimensions,
159
- assuming ffn_multipliers[0] != ffn_multipliers[1]. This results in variable allocation of parameters in the FFN layers.
160
- This scaling is known as layer-wise or block-wise scaling: https://arxiv.org/abs/2008.00623
161
- ffn_with_glu (`bool`, *optional*, defaults to True):
162
- Whether to use FFN with Gated Linear Unit (GLU)
163
- ffn_dim_divisor (`int`, *optional*, defaults to 256):
164
- The ffn layer dimension divisor.
165
- activation_fn_name (`str` or `function`, *optional*, defaults to `"swish"`):
166
- The non-linear activation function (function or string) in the decoder.
167
- normalization_layer_name (`str` or `function`, *optional*, defaults to `"rms_norm"`):
168
- Type of normalization layer.
169
- normalize_qk_projections (`bool`, *optional*, defaults to False):
170
- Whether to normalize queries and keys after projections
171
- share_input_output_layers (`bool`, *optional*, defaults to False):
172
- Whether to share the embedding between input and output linear layer
173
- rope_freq_constant (`int`, *optional*, defaults to 10000):
174
- The base period of the RoPE embeddings.
175
- rope_max_length (`int`, *optional*, defaults to 4096):
176
- Note that rope_max_length is set to twice max_context_length.
177
- This allows flexibility in token lengths during training or fine-tuning.
178
- initializer_range (`float`, *optional*, defaults to 0.02):
179
- The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
180
- use_cache (`bool`, *optional*, defaults to `True`):
181
- Whether or not the model should return the last key/values attentions (not used by all models). Only
182
- relevant if `config.is_decoder=True`.
183
- bos_token_id (`int`, *optional*, defaults to 1):
184
- Beginning of stream token id.
185
- eos_token_id (`int`, *optional*, defaults to 2):
186
- End of stream token id.
187
- """
188
-
189
- model_type = "openelm"
190
-
191
- def __init__(
192
- self,
193
- vocab_size: int = 32000,
194
- max_context_length: int = 2048,
195
- num_transformer_layers: int = 12,
196
- model_dim: int = 2048,
197
- head_dim: int = 128,
198
- qkv_multipliers: Union[Number, List[Number]] = 1.0,
199
- num_query_heads: Union[int, None] = None,
200
- num_gqa_groups: int = 1,
201
- ffn_multipliers: Union[Number, List[Number]] = 4.0,
202
- ffn_with_glu: bool = True,
203
- ffn_dim_divisor: int = 256,
204
- activation_fn_name: str = "swish",
205
- normalization_layer_name: str = "rms_norm",
206
- normalize_qk_projections: bool = False,
207
- share_input_output_layers: bool = False,
208
- rope_freq_constant: int = 10000,
209
- rope_max_length: int = 4096,
210
- initializer_range: float = 0.02,
211
- use_cache: bool = True,
212
- bos_token_id: int = 1,
213
- eos_token_id: int = 2,
214
- **kwargs,
215
- ) -> None:
216
- self.vocab_size = vocab_size
217
- self.max_context_length = max_context_length
218
- self.num_transformer_layers = num_transformer_layers
219
- self.model_dim = model_dim
220
- self.head_dim = head_dim
221
- self.qkv_multipliers = qkv_multipliers
222
- self.num_query_heads = num_query_heads
223
- self.num_gqa_groups = num_gqa_groups
224
- self.ffn_multipliers = ffn_multipliers
225
- self.ffn_with_glu = ffn_with_glu
226
- self.ffn_dim_divisor = ffn_dim_divisor
227
- self.activation_fn_name = activation_fn_name
228
- self.normalization_layer_name = normalization_layer_name
229
- self.normalize_qk_projections = normalize_qk_projections
230
- self.share_input_output_layers = share_input_output_layers
231
- self.rope_freq_constant = rope_freq_constant
232
- self.rope_max_length = rope_max_length
233
- self.num_query_heads = (
234
- compute_heads(model_dim=model_dim, head_dim=head_dim)
235
- if num_query_heads is None
236
- else num_query_heads
237
- )
238
- self.initializer_range = initializer_range
239
-
240
- self.__post_init__()
241
- super().__init__(
242
- use_cache=use_cache,
243
- bos_token_id=bos_token_id,
244
- eos_token_id=eos_token_id,
245
- **kwargs,
246
- )
247
-
248
- def __post_init__(self) -> None:
249
- if self.num_gqa_groups is not None:
250
- head_multiple_of = self.num_gqa_groups
251
- else:
252
- head_multiple_of = 2
253
-
254
- if isinstance(self.qkv_multipliers, Number):
255
- # All attention layers have the same latent dimensions, resulting in uniform allocation of parameters.
256
- qkv_dim = make_divisible(
257
- self.model_dim * self.qkv_multipliers,
258
- divisor=self.head_dim * head_multiple_of,
259
- )
260
- query_dims = [int(qkv_dim)] * self.num_transformer_layers
261
-
262
- elif (
263
- isinstance(self.qkv_multipliers, (tuple, list))
264
- and len(self.qkv_multipliers) == 2
265
- ):
266
- # Each attention layer have different latent dimensions assuming qkv_multipliers[0] != qkv_multipliers[1].
267
- # This results in variable allocation of parameters in attention layer.
268
- # This scaling is known as layer-wise or block-wise scaling: https://arxiv.org/abs/2008.00623
269
- qkv_multipliers = [
270
- round(v, 2)
271
- for v in np.linspace(
272
- self.qkv_multipliers[0],
273
- self.qkv_multipliers[1],
274
- num=self.num_transformer_layers,
275
- dtype=float,
276
- )
277
- ]
278
- # Make sure that scaled model dimension is divisible by scaled head dimension.
279
- query_dims = [
280
- int(
281
- make_divisible(
282
- self.model_dim * m, divisor=self.head_dim * head_multiple_of
283
- )
284
- )
285
- for m in qkv_multipliers
286
- ]
287
- else:
288
- raise NotImplementedError(
289
- f"QKV multipliers should be a single number or a list containing exactly two numbers. Got: {qkv_multipliers}."
290
- )
291
-
292
- # compute the number of query, key, and value heads
293
- # For multi-head and multi-query attention, the number of heads for query, key, and value are the same.
294
- # For group query attention, the number of key and value heads are the same.
295
- self.num_query_heads = [
296
- int(compute_heads(q_dim, self.head_dim)) for q_dim in query_dims
297
- ]
298
- self.num_kv_heads = [
299
- q_heads // self.num_gqa_groups for q_heads in self.num_query_heads
300
- ]
301
-
302
- # Feed-forward network (FFN) multipliers
303
- if isinstance(self.ffn_multipliers, Number):
304
- # All FFN layers have the same latent dimensions, resulting in uniform allocation of parameters.
305
- self.ffn_multipliers = [self.ffn_multipliers] * self.num_transformer_layers
306
- elif isinstance(self.ffn_multipliers, (tuple, list)):
307
- # Each FFN layer have different latent dimensions assuming ffn_multipliers[0] != ffn_multipliers[1].
308
- # This results in variable allocation of parameters in FFN layer.
309
- # This scaling is known as layer-wise or block-wise scaling: https://arxiv.org/abs/2008.00623
310
- if len(self.ffn_multipliers) == 2:
311
- self.ffn_multipliers = [
312
- round(v, 2)
313
- for v in np.linspace(
314
- self.ffn_multipliers[0],
315
- self.ffn_multipliers[1],
316
- num=self.num_transformer_layers,
317
- dtype=float,
318
- )
319
- ]
320
- else:
321
- assert (
322
- len(self.ffn_multipliers) == self.num_transformer_layers
323
- ), f"{len(self.ffn_multipliers)=}!={self.num_transformer_layers=}"
324
- else:
325
- raise NotImplementedError(
326
- f"FFN multipliers should be a single number or a list containing exactly two numbers. Got: {qkv_multipliers}."
327
- )
328
-
329
- # check num_query_heads divisible by num_kv_heads for every layer
330
- for layer_idx in range(len(query_dims)):
331
- assert self.num_query_heads[layer_idx] % self.num_kv_heads[layer_idx] == 0
332
-
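A brief usage sketch, assuming this module is importable so that OpenELMConfig and OpenELM_CONFIGS are in scope; it shows how __post_init__ expands the two-element multiplier ranges into per-layer lists:

config = OpenELMConfig(**OpenELM_CONFIGS["OpenELM-270M"])
print(config.num_transformer_layers)   # 16
print(config.num_query_heads)          # per-layer query heads, increasing with depth
print(config.num_kv_heads)             # per-layer key/value heads (query heads / num_gqa_groups)
print(config.ffn_multipliers)          # FFN multipliers interpolated from 0.5 to 4.0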
333
- class OpenELMRMSNorm(nn.Module):
334
- def __init__(self, num_features: int, eps: float = 1e-6):
335
- """
336
- Initialize the OpenELMRMSNorm normalization layer.
337
- Args:
338
- num_features (int): The size of the feature (last) dimension of the input tensor.
339
- eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
340
- Attributes:
341
- eps (float): A small value added to the denominator for numerical stability.
342
- weight (nn.Parameter): Learnable scaling parameter.
343
- """
344
- super().__init__()
345
- self.eps = eps
346
- self.weight = nn.Parameter(torch.ones(num_features))
347
- self.num_features = num_features
348
-
349
- def _norm(self, x: Tensor) -> Tensor:
350
- """
351
- Apply the OpenELMRMSNorm normalization to the input tensor.
352
- Args:
353
- x (torch.Tensor): The input tensor.
354
- Returns:
355
- torch.Tensor: The normalized tensor.
356
- """
357
- return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
358
-
359
- def forward(self, x: Tensor) -> Tensor:
360
- """
361
- Forward pass through the OpenELMRMSNorm layer.
362
- Args:
363
- x (torch.Tensor): The input tensor.
364
- Returns:
365
- torch.Tensor: The output tensor after applying OpenELMRMSNorm.
366
- """
367
- output = self._norm(x.float()).type_as(x)
368
- return output * self.weight
369
-
370
- def extra_repr(self) -> str:
371
- return (
372
- super().extra_repr() + f"num_features={self.num_features}, eps={self.eps}"
373
- )
374
-
375
-
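A minimal numerical check of the normalization above (a sketch, assuming OpenELMRMSNorm as defined in this file):

import torch

norm = OpenELMRMSNorm(num_features=8, eps=1e-6)
x = torch.randn(2, 4, 8)
# Reference: scale by the reciprocal root-mean-square over the last dimension, then by `weight`.
ref = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + 1e-6) * norm.weight
assert torch.allclose(norm(x), ref, atol=1e-6)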
376
- class OpenELMPreTrainedModel(PreTrainedModel):
377
- config_class = OpenELMConfig
378
- base_model_prefix = "transformer"
379
- supports_gradient_checkpointing = True
380
- _no_split_modules = ["OpenELMDecoderLayer"]
381
- _skip_keys_device_placement = "past_key_values"
382
-
383
- def __init__(self, *inputs, **kwargs) -> None:
384
- super().__init__(*inputs, **kwargs)
385
-
386
- def _init_weights(self, module: nn.Module) -> None:
387
- """Initialize the weights."""
388
- if isinstance(module, nn.Linear):
389
- # Slightly different from the TF version which uses truncated_normal for initialization
390
- # cf https://github.com/pytorch/pytorch/pull/5617
391
- module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
392
- if module.bias is not None:
393
- module.bias.data.zero_()
394
- elif isinstance(module, nn.Embedding):
395
- module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
396
- if module.padding_idx is not None:
397
- module.weight.data[module.padding_idx].zero_()
398
- elif isinstance(module, OpenELMRMSNorm):
399
- module.weight.data.fill_(1.0)
400
-
401
-
402
- def _rotate_half(x: Tensor) -> Tensor:
403
- x1, x2 = x.chunk(2, dim=-1)
404
- return torch.cat((-x2, x1), dim=-1)
405
-
406
-
407
- def _apply_rotary_pos_emb(x: Tensor, pos_sin: Tensor, pos_cos: Tensor) -> Tensor:
408
- return (x * pos_cos) + (_rotate_half(x) * pos_sin)
409
-
410
-
411
- class OpenELMRotaryEmbedding(torch.nn.Module):
412
- """
413
- The rotary position embeddings (aka RoPE) from `RoFormer <https://arxiv.org/abs/2104.09864>`_.
414
- RoPE encodes the position information of tokens using a rotation matrix, and is able to capture
415
- explicit relative positional dependencies.
416
- Args:
417
- model_dim: The dimensionality of the model's hidden state.
418
- max_seq_length: Maximum sequence length.
419
- freq_constant: A constant used for computing frequencies.
420
- """
421
-
422
- def __init__(
423
- self, model_dim: int, max_seq_length: int, freq_constant: int = 10000
424
- ) -> None:
425
- inv_freq = 1.0 / (
426
- freq_constant
427
- ** (torch.arange(0, model_dim, 2, dtype=torch.float32) / model_dim)
428
- )
429
- super().__init__()
430
-
431
- self.model_dim = model_dim
432
- self.freq_constant = freq_constant
433
- self.max_seq_length = max_seq_length
434
-
435
- self.register_buffer("inv_freq", inv_freq, persistent=False)
436
- self._cached_cos = None
437
- self._cached_sin = None
438
- self._cached_seq_length = max_seq_length
439
- self._compute_sin_cos_embeddings(max_seq_length)
440
-
441
- def extra_repr(self) -> str:
442
- return f"\tmodel_dim={self.model_dim}, max_seq_length={self.max_seq_length}, freq_constant={self.freq_constant}"
443
-
444
- def _compute_sin_cos_embeddings(
445
- self,
446
- key_len: int,
447
- key_device: torch.device = torch.device("cpu"),
448
- key_dtype: torch.dtype = torch.float32,
449
- ) -> None:
450
- """
451
- Compute sine and cos embeddings.
452
- Args:
453
- key_len: Number of tokens in the key embeddings in the transformer model.
454
- key_device: Device where the key embeddings are stored.
455
- key_dtype: Data type of the key embeddings.
456
- Returns:
457
- None
458
- ...note:
459
- We recalculate the sine and cosine embeddings if any of the following conditions are met:
460
- 1. The number of tokens in the key embeddings is greater than the cached sequence length.
461
- 2. Sine and cosine caches are empty.
462
- 3. The device or data type of the sine and cosine embeddings does not match that of the key embeddings.
463
- """
464
- if (
465
- key_len > self._cached_seq_length
466
- or self._cached_cos is None
467
- or (self._cached_cos is not None and self._cached_cos.device != key_device)
468
- or (self._cached_cos is not None and self._cached_cos.dtype != key_dtype)
469
- or self._cached_sin is None
470
- or (self._cached_sin is not None and self._cached_sin.device != key_device)
471
- or (self._cached_sin is not None and self._cached_sin.dtype != key_dtype)
472
- ):
473
- self._cached_seq_length = max(key_len, self._cached_seq_length)
474
-
475
- # The shape of 'pos_index' is [number of key tokens]
476
- pos_index = torch.arange(
477
- self._cached_seq_length,
478
- dtype=torch.float32,
479
- device=self.inv_freq.device,
480
- )
481
- # The shape of 'pos_index_theta' is [number of key tokens, model dimension]
482
- pos_index_theta = torch.einsum("i,j->ij", pos_index, self.inv_freq)
483
- # The shape of 'emb' is [number of key tokens, model dimension]
484
- emb = torch.cat((pos_index_theta, pos_index_theta), dim=-1)
485
-
486
- # the shape of cos and sin embeddings is [number of key tokens, model_dim]
487
- cos_emb = emb.cos().to(dtype=key_dtype, device=key_device)
488
- sin_emb = emb.sin().to(dtype=key_dtype, device=key_device)
489
-
490
- # the shape of cached cos and sin embeddings is [1, 1, number of key tokens, model_dim]
491
- self._cached_cos = cos_emb[None, None, :, :]
492
- self._cached_sin = sin_emb[None, None, :, :]
493
-
494
- def forward(
495
- self,
496
- query: torch.Tensor,
497
- key: torch.Tensor,
498
- ) -> Tuple[torch.Tensor, torch.Tensor]:
499
- """
500
- The forward function of RoPE embeddings.
501
- Args:
502
- query: Query embeddings in the transformer model. The shape of query embeddings is
503
- [Batch, number of query heads, number of query tokens, model dimension].
504
- key: Key embeddings in the transformer model. The shape of key embeddings is
505
- [Batch, number of key heads, number of key tokens, model dimension].
506
- Returns:
507
- A tuple containing the query and key embeddings with positional information. The shape of the returned query
508
- and key embeddings is the same as the input query and key embeddings respectively.
509
- ...note:
510
- The RoPE embedding computation is done in full-precision. After the computation, input query and key tensors
511
- are cast to the original input datatype.
512
- """
513
- dim = key.shape[-1]
514
- key_len = key.shape[2]
515
- query_len = query.shape[2]
516
-
517
- assert dim == self.model_dim
518
- assert key.device == query.device
519
- assert key.dtype == query.dtype
520
-
521
- # In the context of self-attention, the lengths of keys and queries are equal.
522
- # However, in generation tasks, such as predicting the next token in a sequence, the lengths of keys and queries
523
- # can differ. For instance, when employing key-value (KV) caching for sequence prediction, the keys
524
- # represent embeddings of previous tokens and the current token, while the query corresponds
525
- # to the embedding of the current token only.
526
- assert (
527
- key_len >= query_len
528
- ), "Number of keys has to be greater than or equal to number of queries."
529
-
530
- query_float = query.float()
531
- key_float = key.float()
532
-
533
- self._compute_sin_cos_embeddings(
534
- key_len, key_device=key_float.device, key_dtype=key_float.dtype
535
- )
536
- query_float = _apply_rotary_pos_emb(
537
- x=query_float,
538
- pos_sin=self._cached_sin[..., key_len - query_len : key_len, :],
539
- pos_cos=self._cached_cos[..., key_len - query_len : key_len, :],
540
- )
541
- key_float = _apply_rotary_pos_emb(
542
- x=key_float,
543
- pos_sin=self._cached_sin[..., :key_len, :],
544
- pos_cos=self._cached_cos[..., :key_len, :],
545
- )
546
-
547
- return query_float.type_as(query), key_float.type_as(key)
548
-
549
-
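A shape-level sketch of the rotary embedding above (assuming OpenELMRotaryEmbedding from this file). The module operates on per-head tensors whose last dimension is head_dim, and key_len may exceed query_len when a KV cache holds earlier tokens:

import torch

rope = OpenELMRotaryEmbedding(model_dim=64, max_seq_length=128, freq_constant=10000)
query = torch.randn(1, 12, 1, 64)  # [batch, query heads, current token only, head_dim]
key = torch.randn(1, 3, 9, 64)     # [batch, kv heads, cached + current tokens, head_dim]
q_rot, k_rot = rope(query, key)
assert q_rot.shape == query.shape and k_rot.shape == key.shape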
550
- class OpenELMMultiHeadCausalAttention(nn.Module):
551
- def __init__(self, config: OpenELMConfig, layer_idx: int) -> None:
552
- super().__init__()
553
- self.layer_idx = layer_idx
554
- head_dim = config.head_dim
555
- q_heads = config.num_query_heads[layer_idx]
556
- k_heads = config.num_kv_heads[layer_idx]
557
- v_heads = config.num_kv_heads[layer_idx]
558
-
559
- self.qkv_proj = nn.Linear(
560
- in_features=config.model_dim,
561
- out_features=(q_heads + k_heads + v_heads) * head_dim,
562
- bias=False,
563
- )
564
-
565
- self.pos_embedding = OpenELMRotaryEmbedding(
566
- model_dim=config.head_dim,
567
- max_seq_length=config.rope_max_length,
568
- freq_constant=config.rope_freq_constant,
569
- )
570
-
571
- if config.normalize_qk_projections:
572
- self.q_norm = OpenELMRMSNorm(
573
- num_features=config.head_dim,
574
- )
575
- self.k_norm = OpenELMRMSNorm(
576
- num_features=config.head_dim,
577
- )
578
- else:
579
- self.q_norm = None
580
- self.k_norm = None
581
-
582
- self.out_proj = nn.Linear(
583
- in_features=q_heads * head_dim,
584
- out_features=config.model_dim,
585
- bias=False,
586
- )
587
-
588
- self.head_dim = config.head_dim
589
- self.num_q_heads = q_heads
590
- self.num_k_heads = k_heads
591
- self.num_v_heads = v_heads
592
- self.transformer_dim = config.model_dim
593
- self.num_groups = self.num_q_heads // self.num_k_heads
594
-
595
- def extra_repr(self) -> str:
596
- return (
597
- super().extra_repr()
598
- + f"query_heads={self.num_q_heads}, key_heads={self.num_k_heads}, value_heads={self.num_v_heads}"
599
- )
600
-
601
- def forward(
602
- self,
603
- hidden_states: torch.Tensor,
604
- attention_mask: Optional[torch.Tensor] = None,
605
- past_key_value: Optional[Cache] = None,
606
- output_attentions: bool = False,
607
- use_cache: bool = False,
608
- cache_position: Optional[torch.LongTensor] = None,
609
- ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
610
- """
611
- Forward pass of multi-head self-attention.
612
- Args:
613
- hidden_states: Input tensor of the shape [batch size, sequence length, model dimension].
614
- past_key_value: Tensor storing the cached keys and values.
615
- output_attentions: output attention weights.
616
- use_cache: Specifies whether to use kv-cache for generation.
617
- cache_position: used for updating the kv-cache.
618
- Returns:
619
- The output of the same shape as the input, optionally with a tensor containing cached keys and values.
620
- """
621
-
622
- # scaled_dot_product_attention does not return attention weights, set output_attentions to False
623
- output_attentions = False
624
- batch_size, seq_length, d_model = hidden_states.size()
625
-
626
- # [B, S, d] --> [B, S, (q_h + k_h + v_h) * h]
627
- qkv = self.qkv_proj(hidden_states)
628
- # [B, S, (q_h + k_h + v_h) * h] --> [B, S, (q_h + k_h + v_h), h]
629
- qkv = qkv.reshape(
630
- batch_size,
631
- seq_length,
632
- self.num_q_heads + self.num_k_heads + self.num_v_heads,
633
- self.head_dim,
634
- )
635
- # [B, S, (q_h + k_h + v_h), h] --> [B, (q_h + k_h + v_h), S, h]
636
- qkv = qkv.transpose(1, 2)
637
- # [B, (q_h + k_h + v_h), S, h] --> [B, q_h, S, h], [B, k_h, S, h], [B, v_h, S, h]
638
- queries, keys, values = qkv.split(
639
- [self.num_q_heads, self.num_k_heads, self.num_v_heads], dim=1
640
- )
641
-
642
- if self.q_norm is not None:
643
- queries = self.q_norm(queries)
644
-
645
- if self.k_norm is not None:
646
- keys = self.k_norm(keys)
647
-
648
- past_key_value = getattr(self, "past_key_value", past_key_value)
649
-
650
- if past_key_value is not None:
651
- # sin and cos are specific to RoPE models; position_ids needed for the static cache
652
- # cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
653
- cache_kwargs = {"cache_position": cache_position}
654
- keys, values = past_key_value.update(
655
- keys, values, self.layer_idx, cache_kwargs
656
- )
657
-
658
- # Add positional embedding
659
- queries, keys = self.pos_embedding(queries, keys)
660
-
661
- if self.num_groups != 1:
662
- # GQA
663
- # [B, k_h, S, h] --> [B, q_h, S, h]
664
- keys = keys.repeat_interleave(self.num_groups, dim=1)
665
- # [B, v_h, S, h] --> [B, q_h, S, h]
666
- values = values.repeat_interleave(self.num_groups, dim=1)
667
-
668
- causal_mask = attention_mask
669
- if attention_mask is not None and cache_position is not None:
670
- causal_mask = causal_mask[:, :, cache_position, : keys.shape[-2]]
671
-
672
- attn_output = F.scaled_dot_product_attention(
673
- queries,
674
- keys,
675
- values,
676
- attn_mask=causal_mask,
677
- dropout_p=0,
678
- )
679
-
680
- attn_output = attn_output.transpose(1, 2).contiguous()
681
- attn_output = attn_output.reshape(
682
- batch_size, seq_length, self.num_q_heads * self.head_dim
683
- )
684
- attn_output = self.out_proj(attn_output)
685
- if not output_attentions:
686
- attn_weights = None
687
- return attn_output, attn_weights, past_key_value
688
-
689
-
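A standalone PyTorch illustration of the grouped-query expansion used in the forward pass above: with num_groups = q_heads // k_heads, key and value heads are repeated so that scaled_dot_product_attention sees one key/value head per query head (shapes are illustrative):

import torch
from torch.nn import functional as F

batch, q_heads, kv_heads, seq, head_dim = 1, 12, 3, 5, 64
num_groups = q_heads // kv_heads                      # 4 query heads share each key/value head
queries = torch.randn(batch, q_heads, seq, head_dim)
keys = torch.randn(batch, kv_heads, seq, head_dim)
values = torch.randn(batch, kv_heads, seq, head_dim)
keys = keys.repeat_interleave(num_groups, dim=1)      # [B, k_h, S, h] --> [B, q_h, S, h]
values = values.repeat_interleave(num_groups, dim=1)  # [B, v_h, S, h] --> [B, q_h, S, h]
out = F.scaled_dot_product_attention(queries, keys, values, is_causal=True)
assert out.shape == (batch, q_heads, seq, head_dim)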
690
- class OpenELMFeedForwardNetwork(nn.Module):
691
- def __init__(self, config: OpenELMConfig, layer_idx: int) -> None:
692
- super().__init__()
693
- ffn_multiplier = config.ffn_multipliers[layer_idx]
694
- intermediate_dim = int(
695
- make_divisible(
696
- ffn_multiplier * config.model_dim,
697
- divisor=config.ffn_dim_divisor,
698
- )
699
- )
700
- if config.ffn_with_glu:
701
- # FFN with Gated linear unit, as described in https://arxiv.org/abs/2002.05202v1.
702
- self.proj_1 = nn.Linear(
703
- in_features=config.model_dim,
704
- out_features=2 * intermediate_dim,
705
- bias=False,
706
- )
707
- self.proj_2 = nn.Linear(
708
- in_features=intermediate_dim,
709
- out_features=config.model_dim,
710
- bias=False,
711
- )
712
- self.ffn_with_glu = True
713
- else:
714
- # Standard FFN, as described in https://arxiv.org/abs/1706.03762
715
- self.proj_1 = nn.Linear(
716
- in_features=config.model_dim,
717
- out_features=intermediate_dim,
718
- bias=False,
719
- )
720
- self.proj_2 = nn.Linear(
721
- in_features=intermediate_dim,
722
- out_features=config.model_dim,
723
- bias=False,
724
- )
725
- self.ffn_with_glu = False
726
-
727
- self.act = ACT2FN[config.activation_fn_name]
728
-
729
- def extra_repr(self) -> str:
730
- return super().extra_repr() + f"(ffn_with_glu) : {self.ffn_with_glu}"
731
-
732
- def forward(self, x: Tensor) -> Tensor:
733
- """Forward function of FFN layer.
734
- Args:
735
- x: Input tensor of the shape [batch size, sequence length, model dimension].
736
- Returns:
737
- A tensor of the same shape as the input.
738
- """
739
- if self.ffn_with_glu:
740
- y_12 = self.proj_1(x)
741
- y_1, y_2 = y_12.chunk(2, dim=-1)
742
- y = self.act(y_1) * y_2
743
- return self.proj_2(y)
744
- else:
745
- return self.proj_2(self.act(self.proj_1(x)))
746
-
747
-
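A standalone sketch of the GLU branch above, using SiLU as the "swish" activation that ACT2FN resolves to (dimensions are illustrative):

import torch
from torch.nn import functional as F

model_dim, intermediate_dim = 16, 32
proj_1 = torch.nn.Linear(model_dim, 2 * intermediate_dim, bias=False)  # value and gate halves
proj_2 = torch.nn.Linear(intermediate_dim, model_dim, bias=False)
x = torch.randn(2, 6, model_dim)          # [batch, seq, model_dim]
y_1, y_2 = proj_1(x).chunk(2, dim=-1)
out = proj_2(F.silu(y_1) * y_2)           # act(y_1) gates y_2, then project back to model_dim
assert out.shape == x.shape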
748
- class OpenELMDecoderLayer(nn.Module):
749
- def __init__(self, config: OpenELMConfig, layer_idx: int) -> None:
750
- super().__init__()
751
- self.attn = OpenELMMultiHeadCausalAttention(config=config, layer_idx=layer_idx)
752
- self.ffn = OpenELMFeedForwardNetwork(config=config, layer_idx=layer_idx)
753
- self.ffn_norm = OpenELMRMSNorm(
754
- num_features=config.model_dim,
755
- )
756
- self.attn_norm = OpenELMRMSNorm(
757
- num_features=config.model_dim,
758
- )
759
-
760
- def forward(
761
- self,
762
- hidden_states: torch.Tensor,
763
- attention_mask: Optional[torch.Tensor] = None,
764
- position_ids: Optional[torch.LongTensor] = None,
765
- past_key_value: Optional[Tuple[torch.Tensor]] = None,
766
- output_attentions: Optional[bool] = False,
767
- use_cache: Optional[bool] = False,
768
- cache_position: Optional[torch.LongTensor] = None,
769
- **kwargs,
770
- ) -> Tuple[
771
- torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
772
- ]:
773
- """
774
- Args:
775
- hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
776
- attention_mask (`torch.FloatTensor`, *optional*):
777
- attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
778
- query_sequence_length, key_sequence_length)` if default attention is used.
779
- output_attentions (`bool`, *optional*):
780
- Whether or not to return the attentions tensors of all attention layers. See `attentions` under
781
- returned tensors for more detail.
782
- use_cache (`bool`, *optional*):
783
- If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
784
- (see `past_key_values`).
785
- past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
786
- """
787
- residual = hidden_states
788
- hidden_states = self.attn_norm(hidden_states)
789
-
790
- # Self Attention
791
- hidden_states, self_attn_weights, present_key_value = self.attn(
792
- hidden_states=hidden_states,
793
- attention_mask=attention_mask,
794
- past_key_value=past_key_value,
795
- output_attentions=output_attentions,
796
- use_cache=use_cache,
797
- cache_position=cache_position,
798
- **kwargs,
799
- )
800
- hidden_states = residual + hidden_states
801
-
802
- # Fully Connected
803
- residual = hidden_states
804
- hidden_states = self.ffn_norm(hidden_states)
805
- hidden_states = self.ffn(hidden_states)
806
- hidden_states = residual + hidden_states
807
-
808
- outputs = (hidden_states,)
809
-
810
- if output_attentions:
811
- outputs += (self_attn_weights,)
812
-
813
- if use_cache:
814
- outputs += (present_key_value,)
815
-
816
- return outputs
817
-
818
-
819
- class OpenELMModel(OpenELMPreTrainedModel):
820
- config_class = OpenELMConfig
821
-
822
- def __init__(self, config: OpenELMConfig):
823
- super().__init__(config)
824
- self.config = config
825
-
826
- self.token_embeddings = nn.Embedding(
827
- embedding_dim=config.model_dim,
828
- num_embeddings=config.vocab_size,
829
- )
830
-
831
- self.layers = nn.ModuleList(
832
- OpenELMDecoderLayer(config=config, layer_idx=layer_idx)
833
- for layer_idx in range(config.num_transformer_layers)
834
- )
835
- self.norm = OpenELMRMSNorm(num_features=config.model_dim)
836
- if config.share_input_output_layers:
837
- self.classifier = None
838
- else:
839
- self.classifier = nn.Linear(
840
- in_features=config.model_dim,
841
- out_features=config.vocab_size,
842
- bias=False,
843
- )
844
- self.num_transformer_layers = config.num_transformer_layers
845
- self.gradient_checkpointing = False
846
-
847
- # Register a causal mask to separate causal and padding mask creation. Merging happens in the attention class.
848
- # NOTE: This is not friendly with TorchScript, ONNX, ExportedProgram serialization for very large `max_context_length`.
849
- causal_mask = torch.full(
850
- (config.max_context_length, config.max_context_length),
851
- fill_value=True,
852
- dtype=torch.bool,
853
- )
854
- self.register_buffer(
855
- "causal_mask", torch.triu(causal_mask, diagonal=1), persistent=False
856
- )
857
-
858
- # Initialize weights and apply final processing
859
- self.post_init()
860
- self.reset_parameters(config=config)
861
-
862
- def get_input_embeddings(self):
863
- return self.token_embeddings
864
-
865
- def set_input_embeddings(self, new_embeddings: torch.Tensor):
866
- self.token_embeddings = new_embeddings
867
-
868
- def reset_parameters(self, config: OpenELMConfig) -> None:
869
- """Initialize the layers in Language Model
870
- The initialization scheme is followed, following `OPT <https://arxiv.org/pdf/2205.01068.pdf>`_.
871
- Args:
872
- use_megatron_std: Use standard deviation as described in Megatron-LM.
873
- Returns:
874
- None
875
- """
876
- for module in self.modules():
877
- if isinstance(module, nn.Linear):
878
- std = module.in_features**-0.5
879
- torch.nn.init.normal_(module.weight, mean=0.0, std=std)
880
- if module.bias is not None:
881
- torch.nn.init.zeros_(module.bias)
882
- elif isinstance(module, nn.Embedding):
883
- std = module.embedding_dim**-0.5
884
- torch.nn.init.normal_(module.weight, mean=0.0, std=std)
885
- elif isinstance(module, OpenELMRMSNorm):
886
- if module.weight is not None:
887
- torch.nn.init.ones_(module.weight)
888
- if hasattr(module, "bias") and module.bias is not None:
889
- torch.nn.init.zeros_(module.bias)
890
-
891
- model_dim = config.model_dim
892
- n_layers = config.num_transformer_layers
893
- std = (model_dim**-0.5) * ((2 * n_layers) ** -0.5)
894
- for param_name, param in self.named_parameters():
895
- if param_name.endswith("out_proj.weight") or param_name.endswith(
896
- "ffn.proj_2.weight"
897
- ):
898
- torch.nn.init.normal_(param, mean=0.0, std=std)
899
-
900
- def forward(
901
- self,
902
- input_ids: torch.LongTensor = None,
903
- attention_mask: Optional[torch.Tensor] = None,
904
- position_ids: Optional[torch.LongTensor] = None,
905
- past_key_values: Optional[List[torch.FloatTensor]] = None,
906
- inputs_embeds: Optional[torch.FloatTensor] = None,
907
- use_cache: Optional[bool] = None,
908
- output_attentions: Optional[bool] = None,
909
- output_hidden_states: Optional[bool] = None,
910
- return_dict: Optional[bool] = None,
911
- cache_position: Optional[torch.LongTensor] = None,
912
- ) -> Union[Tuple, BaseModelOutputWithPast]:
913
- output_attentions = (
914
- output_attentions
915
- if output_attentions is not None
916
- else self.config.output_attentions
917
- )
918
- output_hidden_states = (
919
- output_hidden_states
920
- if output_hidden_states is not None
921
- else self.config.output_hidden_states
922
- )
923
- use_cache = use_cache if use_cache is not None else self.config.use_cache
924
- return_dict = (
925
- return_dict if return_dict is not None else self.config.use_return_dict
926
- )
927
-
928
- if (input_ids is None) ^ (inputs_embeds is not None):
929
- raise ValueError(
930
- "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
931
- )
932
-
933
- if self.gradient_checkpointing and self.training and use_cache:
934
- logger.warning_once(
935
- "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
936
- )
937
- use_cache = False
938
-
939
- if inputs_embeds is None:
940
- inputs_embeds = self.token_embeddings(input_ids)
941
-
942
- past_seen_tokens = 0
943
- if use_cache: # kept for BC (cache positions)
944
- if not isinstance(past_key_values, StaticCache):
945
- past_key_values = DynamicCache.from_legacy_cache(past_key_values)
946
- past_seen_tokens = past_key_values.get_seq_length()
947
-
948
- if cache_position is None:
949
- cache_position = torch.arange(
950
- past_seen_tokens,
951
- past_seen_tokens + inputs_embeds.shape[1],
952
- device=inputs_embeds.device,
953
- )
954
-
955
- if position_ids is None:
956
- position_ids = cache_position.unsqueeze(0)
957
-
958
- causal_mask = self._update_causal_mask(attention_mask, inputs_embeds)
959
-
960
- # embed positions
961
- hidden_states = inputs_embeds
962
-
963
- # decoder layers
964
- all_hidden_states = () if output_hidden_states else None
965
- all_self_attns = () if output_attentions else None
966
- next_decoder_cache = None
967
-
968
- for decoder_layer in self.layers:
969
- if output_hidden_states:
970
- all_hidden_states += (hidden_states,)
971
-
972
- if self.gradient_checkpointing and self.training:
973
- layer_outputs = self._gradient_checkpointing_func(
974
- decoder_layer.__call__,
975
- hidden_states,
976
- causal_mask,
977
- position_ids,
978
- past_key_values,
979
- output_attentions,
980
- use_cache,
981
- cache_position,
982
- )
983
- else:
984
- layer_outputs = decoder_layer(
985
- hidden_states,
986
- attention_mask=causal_mask,
987
- position_ids=position_ids,
988
- past_key_value=past_key_values,
989
- output_attentions=output_attentions,
990
- use_cache=use_cache,
991
- cache_position=cache_position,
992
- )
993
-
994
- hidden_states = layer_outputs[0]
995
-
996
- if use_cache:
997
- next_decoder_cache = layer_outputs[2 if output_attentions else 1]
998
-
999
- if output_attentions:
1000
- all_self_attns += (layer_outputs[1],)
1001
-
1002
- hidden_states = self.norm(hidden_states)
1003
-
1004
- # add hidden states from the last decoder layer
1005
- if output_hidden_states:
1006
- all_hidden_states += (hidden_states,)
1007
-
1008
- next_cache = None
1009
- if use_cache:
1010
- next_cache = (
1011
- next_decoder_cache.to_legacy_cache()
1012
- if isinstance(next_decoder_cache, Cache)
1013
- else next_decoder_cache
1014
- )
1015
- if not return_dict:
1016
- return tuple(
1017
- v
1018
- for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
1019
- if v is not None
1020
- )
1021
- return BaseModelOutputWithPast(
1022
- last_hidden_state=hidden_states,
1023
- past_key_values=next_cache,
1024
- hidden_states=all_hidden_states,
1025
- attentions=all_self_attns,
1026
- )
1027
-
1028
- def _update_causal_mask(self, attention_mask, input_tensor):
1029
- if self.config._attn_implementation == "flash_attention_2":
1030
- if attention_mask is not None and 0.0 in attention_mask:
1031
- return attention_mask
1032
- return None
1033
-
1034
- batch_size, seq_length = input_tensor.shape[:2]
1035
- dtype = input_tensor.dtype
1036
- device = input_tensor.device
1037
-
1038
- # support going beyond cached `max_position_embedding`
1039
- if seq_length > self.causal_mask.shape[-1]:
1040
- causal_mask = torch.full(
1041
- (2 * self.causal_mask.shape[-1], 2 * self.causal_mask.shape[-1]),
1042
- fill_value=1,
1043
- )
1044
- self.register_buffer(
1045
- "causal_mask", torch.triu(causal_mask, diagonal=1), persistent=False
1046
- )
1047
-
1048
- # We use the current dtype to avoid any overflows
1049
- min_dtype = torch.finfo(dtype).min
1050
- causal_mask = (
1051
- self.causal_mask[None, None, :, :].repeat(batch_size, 1, 1, 1).to(dtype)
1052
- * min_dtype
1053
- )
1054
-
1055
- causal_mask = causal_mask.to(dtype=dtype, device=device)
1056
- if attention_mask is not None and attention_mask.dim() == 2:
1057
- mask_length = attention_mask.shape[-1]
1058
- padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[
1059
- :, None, None, :
1060
- ].eq(0.0)
1061
- causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(
1062
- padding_mask, min_dtype
1063
- )
1064
-
1065
- if self.config._attn_implementation == "sdpa" and attention_mask is not None:
1066
- # For dynamo, rather use a check on fullgraph=True once this is possible (https://github.com/pytorch/pytorch/pull/120400).
1067
- is_tracing = (
1068
- torch.jit.is_tracing()
1069
- or isinstance(input_tensor, torch.fx.Proxy)
1070
- or (hasattr(torch, "_dynamo") and torch._dynamo.is_compiling())
1071
- )
1072
- if not is_tracing and torch.any(attention_mask != 1):
1073
- # Attend to all tokens in masked rows from the causal_mask, for example the relevant first rows when
1074
- # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
1075
- # Details: https://github.com/pytorch/pytorch/issues/110213
1076
- causal_mask = causal_mask.mul(
1077
- ~torch.all(causal_mask == min_dtype, dim=-1, keepdim=True)
1078
- ).to(dtype)
1079
-
1080
- return causal_mask
1081
-
1082
-
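A hedged end-to-end sketch of the decoder stack above with a deliberately tiny configuration (assumes this module is importable and a transformers version compatible with the Cache API used in this file; the sizes are arbitrary):

import torch

tiny_config = OpenELMConfig(
    vocab_size=100,
    max_context_length=64,
    rope_max_length=128,
    num_transformer_layers=2,
    model_dim=128,
    head_dim=32,
)
model = OpenELMModel(tiny_config).eval()
input_ids = torch.randint(0, tiny_config.vocab_size, (1, 8))
with torch.no_grad():
    out = model(input_ids=input_ids, use_cache=False)
print(out.last_hidden_state.shape)  # torch.Size([1, 8, 128])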
1083
- class OpenELMForCausalLM(OpenELMPreTrainedModel):
1084
- _tied_weights_keys = ["lm_head.weight"]
1085
-
1086
- def __init__(self, config: OpenELMConfig):
1087
- super().__init__(config)
1088
- self.transformer = OpenELMModel(config)
1089
- self.vocab_size = config.vocab_size
1090
- if config.share_input_output_layers:
1091
- self.lm_head = None
1092
- else:
1093
- self.lm_head = nn.Linear(config.model_dim, config.vocab_size, bias=False)
1094
-
1095
- # Initialize weights and apply final processing
1096
- self.post_init()
1097
-
1098
- def get_input_embeddings(self):
1099
- return self.transformer.token_embeddings
1100
-
1101
- def set_input_embeddings(self, value):
1102
- self.transformer.token_embeddings = value
1103
-
1104
- def get_output_embeddings(self):
1105
- return self.lm_head
1106
-
1107
- def set_output_embeddings(self, new_embeddings):
1108
- self.lm_head = new_embeddings
1109
-
1110
- def set_decoder(self, decoder):
1111
- self.transformer = decoder
1112
-
1113
- def get_decoder(self):
1114
- return self.transformer
1115
-
1116
- def forward(
1117
- self,
1118
- input_ids: torch.LongTensor = None,
1119
- attention_mask: Optional[torch.Tensor] = None,
1120
- position_ids: Optional[torch.LongTensor] = None,
1121
- past_key_values: Optional[List[torch.FloatTensor]] = None,
1122
- inputs_embeds: Optional[torch.FloatTensor] = None,
1123
- labels: Optional[torch.LongTensor] = None,
1124
- use_cache: Optional[bool] = None,
1125
- output_attentions: Optional[bool] = None,
1126
- output_hidden_states: Optional[bool] = None,
1127
- return_dict: Optional[bool] = None,
1128
- cache_position: Optional[torch.LongTensor] = None,
1129
- ) -> Union[Tuple, CausalLMOutputWithPast]:
1130
- output_attentions = (
1131
- output_attentions
1132
- if output_attentions is not None
1133
- else self.config.output_attentions
1134
- )
1135
- output_hidden_states = (
1136
- output_hidden_states
1137
- if output_hidden_states is not None
1138
- else self.config.output_hidden_states
1139
- )
1140
- return_dict = (
1141
- return_dict if return_dict is not None else self.config.use_return_dict
1142
- )
1143
- # decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)
1144
- outputs = self.transformer(
1145
- input_ids=input_ids,
1146
- attention_mask=attention_mask,
1147
- position_ids=position_ids,
1148
- past_key_values=past_key_values,
1149
- inputs_embeds=inputs_embeds,
1150
- use_cache=use_cache,
1151
- output_attentions=output_attentions,
1152
- output_hidden_states=output_hidden_states,
1153
- return_dict=return_dict,
1154
- cache_position=cache_position,
1155
- )
1156
-
1157
- hidden_states = outputs[0]
1158
- if self.lm_head is None:
1159
- # shared
1160
- logits = F.linear(
1161
- hidden_states, weight=self.transformer.token_embeddings.weight
1162
- )
1163
- else:
1164
- logits = self.lm_head(hidden_states)
1165
- logits = logits[:, : self.config.vocab_size]
1166
- loss = None
1167
- if labels is not None:
1168
- # Shift so that tokens < n predict n
1169
- shift_logits = logits[..., :-1, :].contiguous()
1170
- shift_labels = labels[..., 1:].contiguous()
1171
- # Flatten the tokens
1172
- loss_fct = CrossEntropyLoss()
1173
- shift_logits = shift_logits.view(-1, self.config.vocab_size)
1174
- shift_labels = shift_labels.view(-1)
1175
- # Enable model parallelism
1176
- shift_labels = shift_labels.to(shift_logits.device)
1177
- loss = loss_fct(shift_logits, shift_labels)
1178
-
1179
- if not return_dict:
1180
- output = (logits,) + outputs[1:]
1181
- return (loss,) + output if loss is not None else output
1182
-
1183
- return CausalLMOutputWithPast(
1184
- loss=loss,
1185
- logits=logits,
1186
- past_key_values=outputs.past_key_values,
1187
- hidden_states=outputs.hidden_states,
1188
- attentions=outputs.attentions,
1189
- )
1190
-
1191
- def prepare_inputs_for_generation(
1192
- self,
1193
- input_ids,
1194
- past_key_values=None,
1195
- attention_mask=None,
1196
- inputs_embeds=None,
1197
- **kwargs,
1198
- ):
1199
- past_length = 0
1200
- if past_key_values is not None:
1201
- if isinstance(past_key_values, Cache):
1202
- cache_length = past_key_values.get_seq_length()
1203
- past_length = past_key_values.seen_tokens
1204
- max_cache_length = past_key_values.get_max_length()
1205
- else:
1206
- cache_length = past_length = past_key_values[0][0].shape[2]
1207
- max_cache_length = None
1208
-
1209
- # Keep only the unprocessed tokens:
1210
- # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
1211
- # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
1212
- # input)
1213
- if (
1214
- attention_mask is not None
1215
- and attention_mask.shape[1] > input_ids.shape[1]
1216
- ):
1217
- input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
1218
- # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
1219
- # input_ids based on the past_length.
1220
- elif past_length < input_ids.shape[1]:
1221
- input_ids = input_ids[:, past_length:]
1222
- # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
1223
-
1224
- # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
1225
- if (
1226
- max_cache_length is not None
1227
- and attention_mask is not None
1228
- and cache_length + input_ids.shape[1] > max_cache_length
1229
- ):
1230
- attention_mask = attention_mask[:, -max_cache_length:]
1231
-
1232
- position_ids = kwargs.get("position_ids", None)
1233
- if attention_mask is not None and position_ids is None:
1234
- # create position_ids on the fly for batch generation
1235
- position_ids = attention_mask.long().cumsum(-1) - 1
1236
- position_ids.masked_fill_(attention_mask == 0, 1)
1237
- if past_key_values:
1238
- position_ids = position_ids[:, -input_ids.shape[1] :]
1239
-
1240
- if self.generation_config.cache_implementation == "static":
1241
- # generation with static cache
1242
- cache_position = kwargs.get("cache_position", None)
1243
- if cache_position is None:
1244
- past_length = 0
1245
- else:
1246
- past_length = cache_position[-1] + 1
1247
- input_ids = input_ids[:, past_length:]
1248
- position_ids = position_ids[:, past_length:]
1249
-
1250
- # we should only keep a `cache_position` in generate, and do +=1.
1251
- # same goes for position ids. Could also help with continued generation.
1252
- cache_position = torch.arange(
1253
- past_length,
1254
- past_length + position_ids.shape[-1],
1255
- device=position_ids.device,
1256
- )
1257
-
1258
- # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
1259
- if inputs_embeds is not None and past_key_values is None:
1260
- model_inputs = {"inputs_embeds": inputs_embeds}
1261
- else:
1262
- # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
1263
- # recompiles graphs as the stride of the inputs is a guard. Ref: https://github.com/huggingface/transformers/pull/29114
1264
- # We could use `next_tokens` directly instead.
1265
- model_inputs = {"input_ids": input_ids.contiguous()}
1266
-
1267
- model_inputs.update(
1268
- {
1269
- "position_ids": position_ids.contiguous(),
1270
- "cache_position": cache_position,
1271
- "past_key_values": past_key_values,
1272
- "use_cache": kwargs.get("use_cache"),
1273
- "attention_mask": attention_mask,
1274
- }
1275
- )
1276
- return model_inputs
1277
-
1278
- @staticmethod
1279
- def _reorder_cache(past_key_values, beam_idx):
1280
- reordered_past = ()
1281
- for layer_past in past_key_values:
1282
- reordered_past += (
1283
- tuple(
1284
- past_state.index_select(0, beam_idx.to(past_state.device))
1285
- for past_state in layer_past
1286
- ),
1287
- )
1288
- return reordered_past
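Since this file was shipped as remote code, a checkpoint that included it would typically be loaded as below. This is only a sketch: the repository id is a placeholder, the repository is assumed to also provide a tokenizer, and the generation arguments are illustrative.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

repo_id = "some-org/some-openelm-checkpoint"  # placeholder; substitute the actual Hub repo id
tokenizer = AutoTokenizer.from_pretrained(repo_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    repo_id,
    trust_remote_code=True,  # lets transformers load modeling_elm.py from the repository
    torch_dtype=torch.float32,
)
inputs = tokenizer("Once upon a time", return_tensors="pt")
output_ids = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))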