ClaudiaIoana550 committed
Commit 2edb465
1 Parent(s): beda033

Rename modeling_falcon.py to modelling_RW.py

Files changed (1)
  1. modeling_falcon.py → modelling_RW.py +250 -392
modeling_falcon.py → modelling_RW.py RENAMED
@@ -1,20 +1,9 @@
- # coding=utf-8
- # Copyright 2023 the Falcon authors and HuggingFace Inc. team. All rights reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- """PyTorch Falcon model."""
 
  import math
  from typing import Optional, Tuple, Union

  import torch
@@ -31,60 +20,59 @@ from transformers.modeling_outputs import (
  TokenClassifierOutput,
  )
  from transformers.modeling_utils import PreTrainedModel
- from transformers.utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
- from .configuration_falcon import FalconConfig
-

  logger = logging.get_logger(__name__)

- FALCON_PRETRAINED_MODEL_ARCHIVE_LIST = [
- "tiiuae/falcon-40b",
- "tiiuae/falcon-40b-instruct",
- "tiiuae/falcon-7b",
- "tiiuae/falcon-7b-instruct",
- "tiiuae/falcon-rw-7b",
- "tiiuae/falcon-rw-1b",
- ]
- _CHECKPOINT_FOR_DOC = "Rocketknight1/falcon-rw-1b"
- _CONFIG_FOR_DOC = "FalconConfig"
-
-
  # NOTE(Hesslow): Unfortunately we did not fuse matmul and bias during training, this means that there's one additional quantization to bfloat16 between the operations.
  # In order not to degrade the quality of our HF-port, we keep these characteristics in the final model.
- class FalconLinear(nn.Linear):
  def forward(self, input: torch.Tensor) -> torch.Tensor:
- hidden_states = input @ self.weight.T
  if self.bias is None:
- return hidden_states
- return hidden_states + self.bias

  # rotary pos emb helpers (torch.jit.script does not seem to support staticmethod...)
  def rotate_half(x):
  x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
- return torch.cat((-x2, x1), dim=-1)
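Aside, not part of the diff: the NOTE(Hesslow) comment above explains why FalconLinear keeps the matmul and the bias add as two separate bfloat16 ops instead of a fused addmm. A minimal sketch of the extra rounding this can introduce, with made-up shapes:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(4, 8, dtype=torch.bfloat16)
w = torch.randn(16, 8, dtype=torch.bfloat16)
b = torch.randn(16, dtype=torch.bfloat16)

unfused = (x @ w.T) + b   # matmul result is rounded to bfloat16, then the bias is added
fused = F.linear(x, w, b)  # fused path, no intermediate rounding
# The two results may differ by roughly one bfloat16 rounding step.
print((unfused - fused).abs().max())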
 
- class FalconRotaryEmbedding(nn.Module):
  """Implementation of RotaryEmbedding from GPT-NeoX.
- This implementation is designed to operate on queries and keys that are compatible with `[batch_size,
- n_heads_per_partition, seq_len, head_dim]` (e.g. MinGPTAttention format).
  """

- def __init__(self, head_dim: int, base=10000):
  super().__init__()
  inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
  self.register_buffer("inv_freq", inv_freq, persistent=False)
  self.head_dim = head_dim
- self.seq_len_cached = -1
  self.cos_cached: torch.Tensor | None = None
  self.sin_cached: torch.Tensor | None = None

- def cos_sin(self, seq_len: int, past_key_values_length: int, device="cpu", dtype=torch.bfloat16) -> torch.Tensor:
- total_length = seq_len + past_key_values_length
- if total_length > self.seq_len_cached:
- self.seq_len_cached = total_length
- t = torch.arange(total_length, device=device, dtype=self.inv_freq.dtype)
  freqs = torch.einsum("i,j->ij", t, self.inv_freq)
  emb = torch.cat((freqs, freqs), dim=-1).to(device)
@@ -97,46 +85,36 @@ class FalconRotaryEmbedding(nn.Module):
  self.cos_cached = self.cos_cached.type(dtype)
  self.sin_cached = self.sin_cached.type(dtype)

- return (
- self.cos_cached[:, past_key_values_length : seq_len + past_key_values_length],
- self.sin_cached[:, past_key_values_length : seq_len + past_key_values_length],
- )

- def forward(self, query, key, past_key_values_length=0):
- batch, seq_len, head_dim = query.shape
- cos, sin = self.cos_sin(seq_len, past_key_values_length, query.device, query.dtype)
- return (query * cos) + (rotate_half(query) * sin), (key * cos) + (rotate_half(key) * sin)
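For readers skimming the diff, a self-contained sketch, with toy sizes and not taken from the file itself, of what the rotary forward above computes from its cached cos/sin tables:

import torch

def rotate_half(x):
    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

head_dim, seq_len = 64, 10
inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim, 2).float() / head_dim))
t = torch.arange(seq_len, dtype=inv_freq.dtype)
freqs = torch.einsum("i,j->ij", t, inv_freq)     # [seq_len, head_dim // 2]
emb = torch.cat((freqs, freqs), dim=-1)          # [seq_len, head_dim]
cos, sin = emb.cos()[None, :, :], emb.sin()[None, :, :]

q = torch.randn(3, seq_len, head_dim)            # [batch * n_heads, seq_len, head_dim], as in the diff
k = torch.randn(3, seq_len, head_dim)
q_rot = (q * cos) + (rotate_half(q) * sin)
k_rot = (k * cos) + (rotate_half(k) * sin)
print(q_rot.shape, k_rot.shape)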
 
  def _make_causal_mask(
  input_ids_shape: torch.Size, device: torch.device, past_key_values_length: int
  ) -> torch.BoolTensor:
- """
- Make causal mask used for self-attention. This mask does not take the existing attention mask into account - it
- just blocks tokens from attending forwards in the sequence. The output shape will be `[batch_size, 1,
- target_length, target_length+past_key_values_length]`.
- """
  batch_size, target_length = input_ids_shape

- mask = torch.triu(torch.ones((target_length, target_length), dtype=torch.bool, device=device), diagonal=1)
- # If past_key_values_length is 0 this is an empty tensor and the concatenation is a no-op.
- # This code style is an unfortunate consequence of getting your TF engineer to port models; doing it this
- # way avoids a data-dependent conditional, which will help me when I have to port this to XLA later.
- past_mask = torch.zeros((target_length, past_key_values_length), dtype=torch.bool, device=device)
- mask = torch.cat([past_mask, mask], dim=-1)
  expanded_mask = mask[None, None, :, :].expand(batch_size, 1, target_length, target_length + past_key_values_length)
  return expanded_mask

- def _expand_mask(mask: torch.Tensor, past_key_values_length: int) -> torch.BoolTensor:
- """
- Expands attention_mask from `[batch_size, seq_length]` to `[batch_size, 1, seq_length, seq_length + past_length]`.
- """
- batch_size, total_length = mask.shape
- seq_length = total_length - past_key_values_length if past_key_values_length is not None else total_length

  expanded_mask = ~(mask[:, None, None, :].to(torch.bool))
- return expanded_mask.expand(batch_size, 1, seq_length, total_length)
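A toy illustration, not from the file, of the boolean convention used by the removed helpers above: True marks a position a query must not attend to, and all cached past tokens stay visible.

import torch

target_length, past_key_values_length = 3, 2
mask = torch.triu(torch.ones((target_length, target_length), dtype=torch.bool), diagonal=1)
past_mask = torch.zeros((target_length, past_key_values_length), dtype=torch.bool)
mask = torch.cat([past_mask, mask], dim=-1)   # [3, 5]
print(mask)
# tensor([[False, False, False,  True,  True],
#         [False, False, False, False,  True],
#         [False, False, False, False, False]])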
  def build_alibi_tensor(attention_mask: torch.Tensor, num_heads: int, dtype: torch.dtype) -> torch.Tensor:
@@ -167,31 +145,18 @@ def build_alibi_tensor(attention_mask: torch.Tensor, num_heads: int, dtype: torc
  return alibi.reshape(batch_size * num_heads, 1, seq_length).to(dtype)
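build_alibi_tensor itself is unchanged context in this diff (it follows the Bloom recipe), but the shape it returns is worth seeing once. A simplified sketch that assumes num_heads is a power of two; the real helper also handles the general case:

import math
import torch

def simple_alibi(attention_mask, num_heads):
    batch_size, seq_length = attention_mask.shape
    base = 2.0 ** (-(2.0 ** -(math.log2(num_heads) - 3)))
    slopes = torch.pow(torch.tensor(base), torch.arange(1, 1 + num_heads, dtype=torch.float32))
    # Position index of every non-padding token.
    arange = ((attention_mask.cumsum(dim=-1) - 1) * attention_mask)[:, None, :]
    alibi = slopes[None, :, None] * arange
    return alibi.reshape(batch_size * num_heads, 1, seq_length)

print(simple_alibi(torch.ones(2, 5), num_heads=4).shape)  # torch.Size([8, 1, 5])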
 
- # Copied from transformers.models.bloom.modeling_bloom.dropout_add
  def dropout_add(x: torch.Tensor, residual: torch.Tensor, prob: float, training: bool) -> torch.Tensor:
- """
- Dropout add function
- Args:
- x (`torch.tensor`, *required*):
- input tensor
- residual (`torch.tensor`, *required*):
- residual tensor
- prob (`float`, *required*):
- dropout probability
- training (`bool`, *required*):
- training mode
- """
  out = F.dropout(x, p=prob, training=training)
  out = residual + out
  return out
- class FalconAttention(nn.Module):
- def __init__(self, config: FalconConfig):
  super().__init__()

  self.hidden_size = config.hidden_size
- self.num_heads = config.num_attention_heads
  self.head_dim = self.hidden_size // self.num_heads
  self.split_size = self.hidden_size
  self.hidden_dropout = config.hidden_dropout
@@ -202,45 +167,35 @@ class FalconAttention(nn.Module):
  f" {self.num_heads})."
  )

- self.maybe_rotary = FalconRotaryEmbedding(config.head_dim) if config.rotary else lambda q, k, t: (q, k)

  # Layer-wise attention scaling
  self.inv_norm_factor = 1.0 / math.sqrt(self.head_dim)
  self.beta = self.inv_norm_factor
- if config.new_decoder_architecture:
- qkv_out_dim = (config.num_kv_heads * 2 + config.num_attention_heads) * self.head_dim
- elif config.multi_query:
- qkv_out_dim = self.hidden_size + 2 * self.head_dim
- else:
- qkv_out_dim = 3 * self.hidden_size
- self.query_key_value = FalconLinear(self.hidden_size, qkv_out_dim, bias=config.bias)
- self.new_decoder_architecture = config.new_decoder_architecture
  self.multi_query = config.multi_query
- self.dense = FalconLinear(self.hidden_size, self.hidden_size, bias=config.bias)
  self.attention_dropout = nn.Dropout(config.attention_dropout)
- self.num_kv_heads = config.num_kv_heads if (self.new_decoder_architecture or not self.multi_query) else 1
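The three removed branches above size the fused query/key/value projection. A quick sketch with hypothetical Falcon-40B-like numbers, not taken from this file, showing the resulting output widths:

hidden_size = 8192
num_attention_heads = 128
num_kv_heads = 8
head_dim = hidden_size // num_attention_heads                    # 64

grouped = (num_kv_heads * 2 + num_attention_heads) * head_dim    # new_decoder_architecture -> 9216
multi_query = hidden_size + 2 * head_dim                         # single shared K/V head   -> 8320
classic = 3 * hidden_size                                        # full multi-head          -> 24576
print(grouped, multi_query, classic)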
  def _split_heads(self, fused_qkv: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
224
  """
225
- Split the last dimension into (num_heads, head_dim), results share same memory storage as `fused_qkv`
 
 
226
  Args:
227
  fused_qkv (`torch.tensor`, *required*): [batch_size, seq_length, num_heads * 3 * head_dim]
 
228
  Returns:
229
  query: [batch_size, seq_length, num_heads, head_dim] key: [batch_size, seq_length, num_heads, head_dim]
230
  value: [batch_size, seq_length, num_heads, head_dim]
231
  """
232
- if self.new_decoder_architecture:
233
- batch, seq_len, _ = fused_qkv.shape
234
- qkv = fused_qkv.view(batch, seq_len, -1, self.num_heads // self.num_kv_heads + 2, self.head_dim)
235
- query = qkv[:, :, :, :-2]
236
- key = qkv[:, :, :, [-2]]
237
- value = qkv[:, :, :, [-1]]
238
- key = torch.broadcast_to(key, query.shape)
239
- value = torch.broadcast_to(value, query.shape)
240
-
241
- query, key, value = [x.flatten(2, 3) for x in (query, key, value)]
242
- return query, key, value
243
- elif not self.multi_query:
244
  batch_size, seq_length, three_times_hidden_size = fused_qkv.shape
245
  fused_qkv = fused_qkv.view(batch_size, seq_length, self.num_heads, 3, self.head_dim)
246
  return fused_qkv[..., 0, :], fused_qkv[..., 1, :], fused_qkv[..., 2, :]
@@ -249,12 +204,13 @@ class FalconAttention(nn.Module):
249
  fused_qkv = fused_qkv.view(batch_size, seq_length, self.num_heads + 2, self.head_dim)
250
  return fused_qkv[..., :-2, :], fused_qkv[..., [-2], :], fused_qkv[..., [-1], :]
251
 
252
- # Copied from transformers.models.bloom.modeling_bloom.BloomAttention._merge_heads
253
  def _merge_heads(self, x: torch.Tensor) -> torch.Tensor:
254
  """
255
  Merge heads together over the last dimenstion
 
256
  Args:
257
- x (`torch.tensor`, *required*): [batch_size * num_heads, seq_length, head_dim]
 
258
  Returns:
259
  torch.tensor: [batch_size, seq_length, num_heads * head_dim]
260
  """
@@ -276,7 +232,7 @@ class FalconAttention(nn.Module):
276
  def forward(
277
  self,
278
  hidden_states: torch.Tensor,
279
- alibi: Optional[torch.Tensor],
280
  attention_mask: torch.Tensor,
281
  layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
282
  head_mask: Optional[torch.Tensor] = None,
@@ -284,120 +240,105 @@ class FalconAttention(nn.Module):
284
  output_attentions: bool = False,
285
  ):
286
  fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size]
287
- num_kv_heads = self.num_heads if self.new_decoder_architecture else self.num_kv_heads
288
  # 3 x [batch_size, seq_length, num_heads, head_dim]
289
  (query_layer, key_layer, value_layer) = self._split_heads(fused_qkv)
290
 
291
- batch_size, query_length, _, _ = query_layer.shape
292
 
293
- query_layer = query_layer.transpose(1, 2).reshape(batch_size * self.num_heads, query_length, self.head_dim)
294
  key_layer = key_layer.transpose(1, 2).reshape(
295
- batch_size * num_kv_heads,
296
- query_length,
297
  self.head_dim,
298
  )
299
- value_layer = value_layer.transpose(1, 2).reshape(batch_size * num_kv_heads, query_length, self.head_dim)
300
 
301
- past_kv_length = 0 if layer_past is None else layer_past[0].shape[1]
302
- query_layer, key_layer = self.maybe_rotary(query_layer, key_layer, past_kv_length)
303
 
304
  if layer_past is not None:
305
  past_key, past_value = layer_past
306
  # concatenate along seq_length dimension:
307
- # - key: [batch_size * self.num_heads, kv_length, head_dim]
308
  # - value: [batch_size * self.num_heads, kv_length, head_dim]
309
  key_layer = torch.cat((past_key, key_layer), dim=1)
310
  value_layer = torch.cat((past_value, value_layer), dim=1)
311
 
312
  _, kv_length, _ = key_layer.shape
313
- if use_cache:
 
314
  present = (key_layer, value_layer)
315
  else:
316
  present = None
317
 
318
- attention_mask_float = (attention_mask * 1.0).masked_fill(attention_mask, float("-1e9")).to(query_layer.dtype)
319
-
320
- query_layer_ = query_layer.reshape(batch_size, self.num_heads, -1, self.head_dim)
321
- key_layer_ = key_layer.reshape(batch_size, num_kv_heads, -1, self.head_dim)
322
- value_layer_ = value_layer.reshape(batch_size, num_kv_heads, -1, self.head_dim)
323
-
324
  if alibi is None:
325
- if output_attentions:
326
- # F.scaled_dot_product_attention doesn't return the attention weights, so we have
327
- # to do it by hand if we want them
328
- attention_scores = query_layer_ @ key_layer_.transpose(-1, -2)
329
- attention_scores /= math.sqrt(self.head_dim)
330
 
331
- attention_scores = F.softmax(
332
- attention_scores + attention_mask_float, dim=-1, dtype=hidden_states.dtype
333
- )
334
- attn_output = attention_scores @ value_layer_
335
- else:
336
- attn_output = F.scaled_dot_product_attention(
337
- query_layer_, key_layer_, value_layer_, attention_mask_float, 0.0, is_causal=False
338
- )
339
- attention_scores = None
340
 
341
- attn_output = attn_output.view(batch_size, self.num_heads, query_length, self.head_dim)
342
- attn_output = attn_output.permute(0, 2, 1, 3)
343
- attn_output = attn_output.reshape(batch_size, query_length, self.num_heads * self.head_dim)
344
 
345
  output_tensor = self.dense(attn_output)
346
 
347
- if output_attentions:
348
- return output_tensor, present, attention_scores
349
- else:
350
- return output_tensor, present
351
-
352
  else:
353
- matmul_result = query_layer_ @ key_layer_.transpose(-1, -2)
 
354
 
355
  # change view to [batch_size, num_heads, q_length, kv_length]
356
- attention_scores = matmul_result.view(batch_size, self.num_heads, query_length, kv_length)
357
 
358
  # cast attention scores to fp32, compute scaled softmax and cast back to initial dtype - [batch_size, num_heads, q_length, kv_length]
359
  input_dtype = attention_scores.dtype
360
  # `float16` has a minimum value of -65504.0, whereas `bfloat16` and `float32` have a minimum value of `-3.4e+38`
361
  if input_dtype == torch.float16 or input_dtype == torch.bfloat16:
362
  attention_scores = attention_scores.to(torch.float32)
363
- # Matt (HF) note: We could possibly use F.scaled_dot_product_attention here too, by
364
- # adding (alibi * self.inv_norm_factor) to attention_mask_float. I think this would be mathematically
365
- # equivalent and more performant, but there might be a numerical difference. If you're reading this
366
- # and you'd like to experiment and maybe file a PR, feel free!
367
- attention_logits = attention_scores + alibi.view(batch_size, self.num_heads, 1, -1)
368
- attention_logits *= self.inv_norm_factor
369
- attention_probs = F.softmax(attention_logits + attention_mask_float, dim=-1, dtype=hidden_states.dtype)
370
  # [batch_size, num_heads, q_length, kv_length]
371
  attention_probs = self.attention_dropout(attention_probs)
372
 
373
  if head_mask is not None:
374
  attention_probs = attention_probs * head_mask
375
 
376
- # change view [batch_size, num_heads, q_length, kv_length]
377
- attention_probs_reshaped = attention_probs.view(batch_size, self.num_heads, query_length, kv_length)
378
 
379
  # matmul: [batch_size * num_heads, q_length, head_dim]
380
- context_layer = (attention_probs_reshaped @ value_layer_).flatten(0, 1)
381
 
382
  # change view [batch_size, num_heads, q_length, head_dim]
383
  context_layer = self._merge_heads(context_layer)
384
 
385
  output_tensor = self.dense(context_layer)
386
 
 
387
  if output_attentions:
388
- return output_tensor, present, attention_probs
389
- else:
390
- return output_tensor, present
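The removed branch above computes attention by hand only when attention weights must be returned, and otherwise calls F.scaled_dot_product_attention. A toy check, with made-up shapes and requiring a torch build that provides scaled_dot_product_attention, that the two paths agree for an additive float mask:

import math
import torch
import torch.nn.functional as F

b, h, q_len, kv_len, d = 1, 2, 4, 4, 8
q, k, v = torch.randn(b, h, q_len, d), torch.randn(b, h, kv_len, d), torch.randn(b, h, kv_len, d)
blocked = torch.triu(torch.ones(q_len, kv_len, dtype=torch.bool), diagonal=1)
attn_mask = torch.zeros(q_len, kv_len).masked_fill(blocked, float("-1e9"))

scores = (q @ k.transpose(-1, -2)) / math.sqrt(d)
manual = F.softmax(scores + attn_mask, dim=-1) @ v
fused = F.scaled_dot_product_attention(q, k, v, attn_mask, 0.0, is_causal=False)
print(torch.allclose(manual, fused, atol=1e-5))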
391
 
392
 
393
- class FalconMLP(nn.Module):
394
- def __init__(self, config: FalconConfig):
395
  super().__init__()
396
  hidden_size = config.hidden_size
397
 
398
- self.dense_h_to_4h = FalconLinear(hidden_size, 4 * hidden_size, bias=config.bias)
399
  self.act = nn.GELU()
400
- self.dense_4h_to_h = FalconLinear(4 * hidden_size, hidden_size, bias=config.bias)
401
  self.hidden_dropout = config.hidden_dropout
402
 
403
  def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -406,47 +347,43 @@ class FalconMLP(nn.Module):
406
  return x
407
 
408
 
409
- class FalconDecoderLayer(nn.Module):
410
- def __init__(self, config: FalconConfig):
411
  super().__init__()
412
  hidden_size = config.hidden_size
413
- self.num_heads = config.num_attention_heads
414
- self.self_attention = FalconAttention(config)
415
- self.mlp = FalconMLP(config)
 
 
 
 
 
 
 
 
 
416
  self.hidden_dropout = config.hidden_dropout
417
- self.config = config
418
 
419
- if config.new_decoder_architecture:
420
- # The layer norm before self-attention
421
- self.ln_attn = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
422
- # The layer norm before the MLP
423
- self.ln_mlp = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
424
- else:
425
- self.input_layernorm = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
426
- if not config.parallel_attn:
427
- self.post_attention_layernorm = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
428
 
429
  def forward(
430
  self,
431
  hidden_states: torch.Tensor,
432
- alibi: Optional[torch.Tensor],
433
  attention_mask: torch.Tensor,
434
  layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
435
  head_mask: Optional[torch.Tensor] = None,
436
  use_cache: bool = False,
437
  output_attentions: bool = False,
438
  ):
439
- residual = hidden_states
440
 
441
- if self.config.new_decoder_architecture:
442
- attention_layernorm_out = self.ln_attn(hidden_states)
443
- mlp_layernorm_out = self.ln_mlp(hidden_states)
444
- else:
445
- attention_layernorm_out = self.input_layernorm(hidden_states)
446
 
447
  # Self attention.
448
  attn_outputs = self.self_attention(
449
- attention_layernorm_out,
450
  layer_past=layer_past,
451
  attention_mask=attention_mask,
452
  alibi=alibi,
@@ -457,21 +394,16 @@ class FalconDecoderLayer(nn.Module):
457
 
458
  attention_output = attn_outputs[0]
459
 
460
- if not self.config.new_decoder_architecture:
461
- if self.config.parallel_attn:
462
- mlp_layernorm_out = attention_layernorm_out
463
- else:
464
- residual = dropout_add(
465
- attention_output, residual, self.config.attention_dropout, training=self.training
466
- )
467
- mlp_layernorm_out = self.post_attention_layernorm(residual)
468
 
469
  outputs = attn_outputs[1:]
470
 
471
  # MLP.
472
- mlp_output = self.mlp(mlp_layernorm_out)
473
 
474
- if self.config.new_decoder_architecture or self.config.parallel_attn:
475
  mlp_output += attention_output
476
 
477
  output = dropout_add(mlp_output, residual, self.config.hidden_dropout, training=self.training)
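The conditionals above are where the two decoder layouts diverge: with parallel attention both branches read the same layernorm output and share one residual add, while the sequential layout re-normalizes after attention. A shape-free sketch, using hypothetical callables rather than the module's API, of the two data flows:

import torch

def parallel_block(x, ln_attn, ln_mlp, attn, mlp):
    # attention and MLP both see normalized x; one residual add at the end
    return x + attn(ln_attn(x)) + mlp(ln_mlp(x))

def sequential_block(x, ln1, ln2, attn, mlp):
    # attention output re-enters the residual stream before the MLP
    x = x + attn(ln1(x))
    return x + mlp(ln2(x))

identity = lambda t: t
x = torch.randn(2, 3)
print(torch.allclose(parallel_block(x, identity, identity, identity, identity), 3 * x))    # True
print(torch.allclose(sequential_block(x, identity, identity, identity, identity), 4 * x))  # True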
@@ -484,81 +416,24 @@ class FalconDecoderLayer(nn.Module):
484
  return outputs # hidden_states, present, attentions
485
 
486
 
487
- FALCON_START_DOCSTRING = r"""
488
- This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
489
- library implements for all its model (such as downloading or saving, resizing the input embeddings etc.)
490
- This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
491
- Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
492
- and behavior.
493
- Parameters:
494
- config ([`FalconConfig`]): Model configuration class with all the parameters of the model.
495
- Initializing with a config file does not load the weights associated with the model, only the
496
- configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
497
- """
498
-
499
- FALCON_INPUTS_DOCSTRING = r"""
500
- Args:
501
- input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
502
- `input_ids_length` = `sequence_length` if `past_key_values` is `None` else `past_key_values[0][0].shape[2]`
503
- (`sequence_length` of input past key value states). Indices of input sequence tokens in the vocabulary.
504
- If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
505
- `input_ids`.
506
- Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
507
- [`PreTrainedTokenizer.__call__`] for details.
508
- [What are input IDs?](../glossary#input-ids)
509
- past_key_values (`Tuple[Tuple[torch.Tensor]]` of length `config.num_hidden_layers`):
510
- Contains precomputed hidden-states (key and values in the attention blocks) as computed by the model (see
511
- `past_key_values` output below). Can be used to speed up sequential decoding. The `input_ids` which have
512
- their past given to this model should not be passed as `input_ids` as they have already been computed.
513
- Each element of `past_key_values` is a tuple (past_key, past_value):
514
- - past_key: [batch_size * num_heads, head_dim, kv_length]
515
- - past_value: [batch_size * num_heads, kv_length, head_dim]
516
- attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
517
- Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
518
- - 1 for tokens that are **not masked**,
519
- - 0 for tokens that are **masked**.
520
- [What are attention masks?](../glossary#attention-mask)
521
- head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
522
- Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
523
- - 1 indicates the head is **not masked**,
524
- - 0 indicates the head is **masked**.
525
- inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
526
- Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
527
- is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
528
- model's internal embedding lookup matrix.
529
- If `past_key_values` is used, optionally only the last `inputs_embeds` have to be input (see
530
- `past_key_values`).
531
- use_cache (`bool`, *optional*):
532
- If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
533
- `past_key_values`).
534
- output_attentions (`bool`, *optional*):
535
- Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
536
- tensors for more detail.
537
- output_hidden_states (`bool`, *optional*):
538
- Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
539
- more detail.
540
- return_dict (`bool`, *optional*):
541
- Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple.
542
- """
543
-
544
-
545
- class FalconPreTrainedModel(PreTrainedModel):
546
  """
547
  An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
548
  models.
549
  """
550
 
551
- config_class = FalconConfig
552
  base_model_prefix = "transformer"
553
  supports_gradient_checkpointing = True
554
- _no_split_modules = ["FalconDecoderLayer"]
555
 
556
  def __init__(self, *inputs, **kwargs):
557
  super().__init__(*inputs, **kwargs)
558
 
559
  def _init_weights(self, module: nn.Module):
560
  """Initialize the weights."""
561
- if isinstance(module, nn.Linear) or isinstance(module, FalconLinear):
562
  # Slightly different from the TF version which uses truncated_normal for initialization
563
  # cf https://github.com/pytorch/pytorch/pull/5617
564
  module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
@@ -572,28 +447,26 @@ class FalconPreTrainedModel(PreTrainedModel):
572
  module.bias.data.zero_()
573
  module.weight.data.fill_(1.0)
574
 
575
- # Copied from transformers.models.bloom.modeling_bloom.BloomPreTrainedModel._set_gradient_checkpointing with BloomModel->FalconModel
576
  def _set_gradient_checkpointing(self, module: nn.Module, value: bool = False):
577
- if isinstance(module, FalconModel):
578
  module.gradient_checkpointing = value
579
 
580
  @staticmethod
581
- def _convert_cache_to_standard_format(
582
  past_key_value: Tuple[Tuple[torch.Tensor, torch.Tensor]], batch_size: int
583
  ) -> Tuple[Tuple[torch.Tensor, torch.Tensor]]:
584
  """
585
  Standardizes the format of the cache so as to match most implementations, i.e. to tuple(tuple([batch_size,
586
  num_heads, ...]))
587
  """
588
- batch_size_times_num_heads, kv_length, head_dim = past_key_value[0][0].shape
589
- # [batch_size * self.num_heads, kv_length, head_dim] -> [batch_size, num_heads, kv_length, head_dim]
590
- # Note that don't want to use self.num_attention_heads because the number of heads may vary depending
591
- # on whether we use multi_query attention.
592
  num_heads = batch_size_times_num_heads // batch_size
 
 
593
  return tuple(
594
  (
595
- layer_past[0].view(batch_size, num_heads, kv_length, head_dim),
596
- layer_past[1].view(batch_size, num_heads, kv_length, head_dim),
597
  )
598
  for layer_past in past_key_value
599
  )
@@ -602,35 +475,32 @@ class FalconPreTrainedModel(PreTrainedModel):
602
  def _convert_to_rw_cache(
603
  past_key_value: Tuple[Tuple[torch.Tensor, torch.Tensor]]
604
  ) -> Tuple[Tuple[torch.Tensor, torch.Tensor]]:
605
- batch_size, num_heads, kv_length, head_dim = past_key_value[0][0].shape
606
  batch_size_times_num_heads = batch_size * num_heads
607
- # [batch_size, num_heads, kv_length, head_dim] -> [batch_size * num_heads, kv_length, head_dim]
 
608
  return tuple(
609
  (
610
- layer_past[0].view(batch_size_times_num_heads, kv_length, head_dim),
611
- layer_past[1].view(batch_size_times_num_heads, kv_length, head_dim),
612
  )
613
  for layer_past in past_key_value
614
  )
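The two cache converters above are pure reshapes between the flattened [batch_size * num_heads, ...] layout used inside attention and the [batch_size, num_heads, ...] layout the rest of the library expects. A toy round-trip, with made-up sizes, showing that no data is copied:

import torch

batch_size, num_heads, kv_length, head_dim = 2, 4, 5, 8
key = torch.randn(batch_size * num_heads, kv_length, head_dim)

standard = key.view(batch_size, num_heads, kv_length, head_dim)      # internal -> standard layout
roundtrip = standard.view(batch_size * num_heads, kv_length, head_dim)
print(roundtrip.data_ptr() == key.data_ptr())  # True: views share storage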
615
 
616
 
617
- @add_start_docstrings(
618
- "The bare Falcon Model transformer outputting raw hidden-states without any specific head on top.",
619
- FALCON_START_DOCSTRING,
620
- )
621
- class FalconModel(FalconPreTrainedModel):
622
- def __init__(self, config: FalconConfig):
623
  super().__init__(config)
624
 
625
  self.embed_dim = config.hidden_size
626
- self.num_heads = config.num_attention_heads
627
- self.use_alibi = config.alibi
628
 
629
  # Embedding + LN Embedding
630
  self.word_embeddings = nn.Embedding(config.vocab_size, self.embed_dim)
631
 
632
  # Transformer blocks
633
- self.h = nn.ModuleList([FalconDecoderLayer(config) for _ in range(config.num_hidden_layers)])
634
 
635
  # Final Layer Norm
636
  self.ln_f = LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
@@ -643,31 +513,22 @@ class FalconModel(FalconPreTrainedModel):
643
  def get_input_embeddings(self):
644
  return self.word_embeddings
645
 
646
- @staticmethod
647
  def _prepare_attn_mask(
648
- attention_mask: torch.Tensor, input_shape: Tuple[int, int], past_key_values_length: int
649
  ) -> torch.BoolTensor:
650
- # Create a causal mask
651
- # The attention mask we receive as input should cover the whole extended sequence, including any past
652
- # cache, so its shape should be [batch_size, seq_length + past_key_values_length]
653
- # The output shape will be [batch_size, 1, seq_length, seq_length + past_key_values_length]
654
- if input_shape[1] + past_key_values_length != attention_mask.shape[1]:
655
- raise ValueError(
656
- "Attention mask shape should be (batch_size, seq_length + past_key_values_length)"
657
- f" but is {attention_mask.shape} with input_ids shape {input_shape} and past length"
658
- f" {past_key_values_length}."
659
- )
660
  combined_attention_mask = None
661
  device = attention_mask.device
662
- _, seq_length = input_shape
663
 
664
- if seq_length > 1:
665
  combined_attention_mask = _make_causal_mask(
666
  input_shape, device=device, past_key_values_length=past_key_values_length
667
  )
668
 
669
- # [batch_size, seq_length + past_key_values_length] -> [batch_size, 1, seq_length, seq_length + past_key_values_length]
670
- expanded_attn_mask = _expand_mask(attention_mask, past_key_values_length=past_key_values_length)
671
  combined_attention_mask = (
672
  expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask | combined_attention_mask
673
  )
@@ -677,12 +538,6 @@ class FalconModel(FalconPreTrainedModel):
677
  def set_input_embeddings(self, new_embeddings: torch.Tensor):
678
  self.word_embeddings = new_embeddings
679
 
680
- @add_start_docstrings_to_model_forward(FALCON_INPUTS_DOCSTRING)
681
- @add_code_sample_docstrings(
682
- checkpoint=_CHECKPOINT_FOR_DOC,
683
- output_type=BaseModelOutputWithPastAndCrossAttentions,
684
- config_class=_CONFIG_FOR_DOC,
685
- )
686
  def forward(
687
  self,
688
  input_ids: Optional[torch.LongTensor] = None,
@@ -694,7 +549,18 @@ class FalconModel(FalconPreTrainedModel):
694
  output_attentions: Optional[bool] = None,
695
  output_hidden_states: Optional[bool] = None,
696
  return_dict: Optional[bool] = None,
 
697
  ) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]:
 
 
 
 
 
 
 
 
 
 
698
  output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
699
  output_hidden_states = (
700
  output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -713,14 +579,12 @@ class FalconModel(FalconPreTrainedModel):
713
 
714
  if past_key_values is None:
715
  past_key_values = tuple([None] * len(self.h))
716
- else:
717
- past_key_values = self._convert_to_rw_cache(past_key_values)
718
 
719
  # Prepare head mask if needed
720
  # 1.0 in head_mask indicate we keep the head
721
  # attention_probs has shape batch_size x num_heads x N x N
722
  # head_mask has shape n_layer x batch x num_heads x N x N
723
- head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
724
 
725
  if inputs_embeds is None:
726
  inputs_embeds = self.word_embeddings(input_ids)
@@ -732,15 +596,17 @@ class FalconModel(FalconPreTrainedModel):
732
  all_hidden_states = () if output_hidden_states else None
733
 
734
  # Compute alibi tensor: check build_alibi_tensor documentation
 
735
  past_key_values_length = 0
736
  if past_key_values[0] is not None:
737
- past_key_values_length = past_key_values[0][0].shape[1] # 1 because RW-cache, not standard format
 
738
  if attention_mask is None:
739
- attention_mask = torch.ones((batch_size, seq_length + past_key_values_length), device=hidden_states.device)
740
  else:
741
  attention_mask = attention_mask.to(hidden_states.device)
742
 
743
- if self.use_alibi:
744
  alibi = build_alibi_tensor(attention_mask, self.num_heads, dtype=hidden_states.dtype)
745
  else:
746
  alibi = None
@@ -752,10 +618,12 @@ class FalconModel(FalconPreTrainedModel):
752
  )
753
 
754
  for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
 
755
  if output_hidden_states:
756
  all_hidden_states = all_hidden_states + (hidden_states,)
757
 
758
  if self.gradient_checkpointing and self.training:
 
759
  if use_cache:
760
  logger.warning(
761
  "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
@@ -800,9 +668,6 @@ class FalconModel(FalconPreTrainedModel):
800
  if output_hidden_states:
801
  all_hidden_states = all_hidden_states + (hidden_states,)
802
 
803
- if presents is not None:
804
- presents = self._convert_cache_to_standard_format(presents, batch_size)
805
-
806
  if not return_dict:
807
  return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)
808
 
@@ -814,16 +679,12 @@ class FalconModel(FalconPreTrainedModel):
814
  )
815
 
816
 
817
- @add_start_docstrings(
818
- "The Falcon Model transformer with a language modeling head on top (linear layer with weights tied to the input embeddings).",
819
- FALCON_START_DOCSTRING,
820
- )
821
- class FalconForCausalLM(FalconPreTrainedModel):
822
- _tied_weights_keys = ["lm_head.weight"]
823
 
824
- def __init__(self, config: FalconConfig):
825
  super().__init__(config)
826
- self.transformer = FalconModel(config)
827
  self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
828
 
829
  # Initialize weights and apply final processing
@@ -838,26 +699,25 @@ class FalconForCausalLM(FalconPreTrainedModel):
838
  def prepare_inputs_for_generation(
839
  self,
840
  input_ids: torch.LongTensor,
841
- past_key_values: Optional[torch.Tensor] = None,
842
  attention_mask: Optional[torch.Tensor] = None,
843
  **kwargs,
844
  ) -> dict:
845
- if past_key_values is not None:
846
- input_ids = input_ids[:, -1:]
 
 
 
 
 
847
 
848
  return {
849
  "input_ids": input_ids,
850
- "past_key_values": past_key_values,
851
  "use_cache": kwargs.get("use_cache"),
852
  "attention_mask": attention_mask,
853
  }
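During incremental generation only the newest token has to be run through the model; earlier positions are already in the key/value cache, which is why the removed lines above slice input_ids. A toy illustration with hypothetical tensors:

import torch

input_ids = torch.tensor([[5, 17, 42, 9]])   # tokens generated so far
past_key_values = object()                   # stand-in for any non-empty cache

if past_key_values is not None:
    input_ids = input_ids[:, -1:]            # keep only the last token

print(input_ids)  # tensor([[9]])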
854
 
855
- @add_start_docstrings_to_model_forward(FALCON_INPUTS_DOCSTRING)
856
- @add_code_sample_docstrings(
857
- checkpoint=_CHECKPOINT_FOR_DOC,
858
- output_type=CausalLMOutputWithCrossAttentions,
859
- config_class=_CONFIG_FOR_DOC,
860
- )
861
  def forward(
862
  self,
863
  input_ids: Optional[torch.LongTensor] = None,
@@ -870,6 +730,7 @@ class FalconForCausalLM(FalconPreTrainedModel):
870
  output_attentions: Optional[bool] = None,
871
  output_hidden_states: Optional[bool] = None,
872
  return_dict: Optional[bool] = None,
 
873
  ) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]:
874
  r"""
875
  labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -877,6 +738,15 @@ class FalconForCausalLM(FalconPreTrainedModel):
877
  `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
878
  are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
879
  """
 
 
 
 
 
 
 
 
 
880
 
881
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
882
 
@@ -926,8 +796,10 @@ class FalconForCausalLM(FalconPreTrainedModel):
926
  This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
927
  [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
928
  beam_idx at every generation step.
 
929
  Output shares the same memory storage as `past`.
930
  """
 
931
 
932
  # Get a copy of `beam_idx` on all the devices where we need those indices.
933
  device_to_beam_idx = {
@@ -938,40 +810,23 @@ class FalconForCausalLM(FalconPreTrainedModel):
938
  layer_past[0].index_select(0, device_to_beam_idx[layer_past[0].device]),
939
  layer_past[1].index_select(0, device_to_beam_idx[layer_past[0].device]),
940
  )
941
- for layer_past in past
942
  )
943
- return reordered_past
944
 
945
 
946
- @add_start_docstrings(
947
- """
948
- The Falcon Model transformer with a sequence classification head on top (linear layer).
949
- [`FalconForSequenceClassification`] uses the last token in order to do the classification, as other causal models
950
- (e.g. GPT-1) do.
951
- Since it does classification on the last token, it requires to know the position of the last token. If a
952
- `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
953
- no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
954
- padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
955
- each row of the batch).
956
- """,
957
- FALCON_START_DOCSTRING,
958
- )
959
- class FalconForSequenceClassification(FalconPreTrainedModel):
960
- def __init__(self, config: FalconConfig):
961
  super().__init__(config)
962
  self.num_labels = config.num_labels
963
- self.transformer = FalconModel(config)
964
  self.score = nn.Linear(config.hidden_size, config.num_labels, bias=False)
965
 
966
  # Initialize weights and apply final processing
967
  self.post_init()
968
 
969
- @add_start_docstrings_to_model_forward(FALCON_INPUTS_DOCSTRING)
970
- @add_code_sample_docstrings(
971
- checkpoint=_CHECKPOINT_FOR_DOC,
972
- output_type=SequenceClassifierOutputWithPast,
973
- config_class=_CONFIG_FOR_DOC,
974
- )
975
  def forward(
976
  self,
977
  input_ids: Optional[torch.LongTensor] = None,
@@ -984,6 +839,7 @@ class FalconForSequenceClassification(FalconPreTrainedModel):
984
  output_attentions: Optional[bool] = None,
985
  output_hidden_states: Optional[bool] = None,
986
  return_dict: Optional[bool] = None,
 
987
  ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutputWithPast]:
988
  r"""
989
  labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@@ -991,6 +847,15 @@ class FalconForSequenceClassification(FalconPreTrainedModel):
991
  config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
992
  `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
993
  """
 
 
 
 
 
 
 
 
 
994
 
995
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
996
 
@@ -1065,22 +930,17 @@ class FalconForSequenceClassification(FalconPreTrainedModel):
1065
  )
1066
 
1067
 
1068
- @add_start_docstrings(
1069
- """
1070
- Falcon Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
1071
- Named-Entity-Recognition (NER) tasks.
1072
- """,
1073
- FALCON_START_DOCSTRING,
1074
- )
1075
- class FalconForTokenClassification(FalconPreTrainedModel):
1076
- def __init__(self, config: FalconConfig):
1077
  super().__init__(config)
1078
  self.num_labels = config.num_labels
1079
 
1080
- self.transformer = FalconModel(config)
1081
- if getattr(config, "classifier_dropout", None) is not None:
1082
  classifier_dropout = config.classifier_dropout
1083
- elif getattr(config, "hidden_dropout", None) is not None:
1084
  classifier_dropout = config.hidden_dropout
1085
  else:
1086
  classifier_dropout = 0.1
@@ -1090,12 +950,6 @@ class FalconForTokenClassification(FalconPreTrainedModel):
1090
  # Initialize weights and apply final processing
1091
  self.post_init()
1092
 
1093
- @add_start_docstrings_to_model_forward(FALCON_INPUTS_DOCSTRING)
1094
- @add_code_sample_docstrings(
1095
- checkpoint=_CHECKPOINT_FOR_DOC,
1096
- output_type=TokenClassifierOutput,
1097
- config_class=_CONFIG_FOR_DOC,
1098
- )
1099
  def forward(
1100
  self,
1101
  input_ids: Optional[torch.LongTensor] = None,
@@ -1108,6 +962,7 @@ class FalconForTokenClassification(FalconPreTrainedModel):
1108
  output_attentions: Optional[bool] = None,
1109
  output_hidden_states: Optional[bool] = None,
1110
  return_dict: Optional[bool] = None,
 
1111
  ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
1112
  r"""
1113
  labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@@ -1115,6 +970,15 @@ class FalconForTokenClassification(FalconPreTrainedModel):
1115
  config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1116
  `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1117
  """
 
 
 
 
 
 
 
 
 
1118
 
1119
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1120
 
@@ -1138,9 +1002,7 @@ class FalconForTokenClassification(FalconPreTrainedModel):
1138
  if labels is not None:
1139
  batch_size, seq_length = labels.shape
1140
  loss_fct = CrossEntropyLoss()
1141
- loss = loss_fct(
1142
- logits.view(batch_size * seq_length, self.num_labels), labels.view(batch_size * seq_length)
1143
- )
1144
 
1145
  if not return_dict:
1146
  output = (logits,) + transformer_outputs[2:]
@@ -1154,27 +1016,22 @@ class FalconForTokenClassification(FalconPreTrainedModel):
1154
  )
1155
 
1156
 
1157
- @add_start_docstrings(
1158
- """
1159
- The Falcon Model transformer with a span classification head on top for extractive question-answering tasks like
1160
- SQuAD (a linear layers on top of the hidden-states output to compute `span start logits` and `span end logits`).
1161
- """,
1162
- FALCON_START_DOCSTRING,
1163
- )
1164
- class FalconForQuestionAnswering(FalconPreTrainedModel):
1165
  def __init__(self, config):
1166
  super().__init__(config)
1167
- self.transformer = FalconModel(config)
1168
  self.qa_outputs = nn.Linear(config.hidden_size, 2)
1169
 
1170
  # Initialize weights and apply final processing
1171
  self.post_init()
1172
 
1173
- @add_start_docstrings_to_model_forward(FALCON_INPUTS_DOCSTRING)
1174
  def forward(
1175
  self,
1176
  input_ids: Optional[torch.LongTensor] = None,
1177
  attention_mask: Optional[torch.FloatTensor] = None,
 
1178
  head_mask: Optional[torch.FloatTensor] = None,
1179
  inputs_embeds: Optional[torch.FloatTensor] = None,
1180
  start_positions: Optional[torch.LongTensor] = None,
@@ -1198,6 +1055,7 @@ class FalconForQuestionAnswering(FalconPreTrainedModel):
1198
  outputs = self.transformer(
1199
  input_ids,
1200
  attention_mask=attention_mask,
 
1201
  head_mask=head_mask,
1202
  inputs_embeds=inputs_embeds,
1203
  output_attentions=output_attentions,
 
+ # port of models described in RW
+ # We use the bloom model as a starting point for these model.
+ # Please refer to the bloom models for usage instructions.

  import math
+ import warnings
  from typing import Optional, Tuple, Union

  import torch
 
  TokenClassifierOutput,
  )
  from transformers.modeling_utils import PreTrainedModel
+ from transformers.utils import logging
+ from .configuration_RW import RWConfig

  logger = logging.get_logger(__name__)
  # NOTE(Hesslow): Unfortunately we did not fuse matmul and bias during training, this means that there's one additional quantization to bfloat16 between the operations.
  # In order not to degrade the quality of our HF-port, we keep these characteristics in the final model.
+ class Linear(nn.Linear):
  def forward(self, input: torch.Tensor) -> torch.Tensor:
+ ret = input @ self.weight.T
  if self.bias is None:
+ return ret
+ else:
+ return ret + self.bias
+

+ from einops import rearrange

  # rotary pos emb helpers (torch.jit.script does not seem to support staticmethod...)
  def rotate_half(x):
  x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
+ return torch.cat((-x2, x1), dim=x1.ndim - 1) # dim=-1 triggers a bug in torch < 1.8.0
46
 
47
+ class RotaryEmbedding(torch.nn.Module):
48
  """Implementation of RotaryEmbedding from GPT-NeoX.
49
+ This implementation is design to operate on queries and keys that are compatible with
50
+ [batch_size, n_heads_per_partition, seq_len, head_dim] (e.g. MinGPTAttention format).
51
  """
52
 
53
+ def __init__(
54
+ self,
55
+ head_dim: int,
56
+ base=10000,
57
+ ):
58
  super().__init__()
59
  inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
60
  self.register_buffer("inv_freq", inv_freq, persistent=False)
61
  self.head_dim = head_dim
62
+ self.seq_len_cached = None
63
+ self.batch_size_cached = None
64
  self.cos_cached: torch.Tensor | None = None
65
  self.sin_cached: torch.Tensor | None = None
66
 
67
+ def cos_sin(
68
+ self,
69
+ seq_len: int,
70
+ device="cuda",
71
+ dtype=torch.bfloat16,
72
+ ) -> torch.Tensor:
73
+ if seq_len != self.seq_len_cached:
74
+ self.seq_len_cached = seq_len
75
+ t = torch.arange(seq_len, device=device).type_as(self.inv_freq)
76
  freqs = torch.einsum("i,j->ij", t, self.inv_freq)
77
  emb = torch.cat((freqs, freqs), dim=-1).to(device)
78
 
 
85
  self.cos_cached = self.cos_cached.type(dtype)
86
  self.sin_cached = self.sin_cached.type(dtype)
87
 
88
+ return self.cos_cached, self.sin_cached
 
 
 
89
 
90
+ def forward(self, q, k):
91
+ batch, seq_len, head_dim = q.shape
92
+ cos, sin = self.cos_sin(seq_len, q.device, q.dtype)
93
+ return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)
94
 
95
 
96
  def _make_causal_mask(
  input_ids_shape: torch.Size, device: torch.device, past_key_values_length: int
  ) -> torch.BoolTensor:
  batch_size, target_length = input_ids_shape
+ mask = torch.empty((target_length, target_length + past_key_values_length), dtype=torch.bool, device=device)
+ # ONNX doesn't support `torch.Tensor.triu` properly, thus we use this workaround
+ seq_ids = torch.arange(target_length, device=device)
+ mask[:, past_key_values_length:] = seq_ids[:, None] < seq_ids[None, :]
+
+ if past_key_values_length > 0:
+ mask[:, :past_key_values_length] = False

  expanded_mask = mask[None, None, :, :].expand(batch_size, 1, target_length, target_length + past_key_values_length)
  return expanded_mask
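The added version builds the causal pattern with an arange comparison because of the ONNX limitation noted in the comment. A quick toy check, not from the file, that it matches the triu construction used on the removed side when there is no past:

import torch

n = 4
seq_ids = torch.arange(n)
compare_mask = seq_ids[:, None] < seq_ids[None, :]
triu_mask = torch.triu(torch.ones(n, n, dtype=torch.bool), diagonal=1)
print(torch.equal(compare_mask, triu_mask))  # True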
 
+ def _expand_mask(mask: torch.Tensor, tgt_length: int) -> torch.BoolTensor:
+ batch_size, src_length = mask.shape
+ tgt_length = tgt_length if tgt_length is not None else src_length

  expanded_mask = ~(mask[:, None, None, :].to(torch.bool))
+ return expanded_mask.expand(batch_size, 1, tgt_length, src_length)
119
 
120
  def build_alibi_tensor(attention_mask: torch.Tensor, num_heads: int, dtype: torch.dtype) -> torch.Tensor:
 
145
  return alibi.reshape(batch_size * num_heads, 1, seq_length).to(dtype)
146
 
147
 
 
148
  def dropout_add(x: torch.Tensor, residual: torch.Tensor, prob: float, training: bool) -> torch.Tensor:
 
 
 
 
 
 
 
 
 
 
 
 
149
  out = F.dropout(x, p=prob, training=training)
150
  out = residual + out
151
  return out
152
 
153
 
154
+ class Attention(nn.Module):
155
+ def __init__(self, config: RWConfig):
156
  super().__init__()
157
 
158
  self.hidden_size = config.hidden_size
159
+ self.num_heads = config.n_head
160
  self.head_dim = self.hidden_size // self.num_heads
161
  self.split_size = self.hidden_size
162
  self.hidden_dropout = config.hidden_dropout
 
167
  f" {self.num_heads})."
168
  )
169
 
170
+ self.maybe_rotary = RotaryEmbedding(config.head_dim) if config.rotary else lambda q, k: (q, k)
171
 
172
  # Layer-wise attention scaling
173
  self.inv_norm_factor = 1.0 / math.sqrt(self.head_dim)
174
  self.beta = self.inv_norm_factor
175
+
176
+ self.query_key_value = Linear(
177
+ self.hidden_size,
178
+ 3 * self.hidden_size if not config.multi_query else (self.hidden_size + 2 * self.head_dim),
179
+ bias=config.bias,
180
+ )
 
 
181
  self.multi_query = config.multi_query
182
+ self.dense = Linear(self.hidden_size, self.hidden_size, bias=config.bias)
183
  self.attention_dropout = nn.Dropout(config.attention_dropout)
184
+ self.num_kv = config.n_head if not self.multi_query else 1
185
 
186
  def _split_heads(self, fused_qkv: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
187
  """
188
+ Split the last dimension into (num_heads, head_dim) without making any copies, results share same memory
189
+ storage as `fused_qkv`
190
+
191
  Args:
192
  fused_qkv (`torch.tensor`, *required*): [batch_size, seq_length, num_heads * 3 * head_dim]
193
+
194
  Returns:
195
  query: [batch_size, seq_length, num_heads, head_dim] key: [batch_size, seq_length, num_heads, head_dim]
196
  value: [batch_size, seq_length, num_heads, head_dim]
197
  """
198
+ if not self.multi_query:
 
 
 
 
 
 
 
 
 
 
 
199
  batch_size, seq_length, three_times_hidden_size = fused_qkv.shape
200
  fused_qkv = fused_qkv.view(batch_size, seq_length, self.num_heads, 3, self.head_dim)
201
  return fused_qkv[..., 0, :], fused_qkv[..., 1, :], fused_qkv[..., 2, :]
 
204
  fused_qkv = fused_qkv.view(batch_size, seq_length, self.num_heads + 2, self.head_dim)
205
  return fused_qkv[..., :-2, :], fused_qkv[..., [-2], :], fused_qkv[..., [-1], :]
206
 
 
207
  def _merge_heads(self, x: torch.Tensor) -> torch.Tensor:
208
  """
209
  Merge heads together over the last dimenstion
210
+
211
  Args:
212
+ x: (`torch.tensor`, *required*): [batch_size * num_heads, seq_length, head_dim]
213
+
214
  Returns:
215
  torch.tensor: [batch_size, seq_length, num_heads * head_dim]
216
  """
 
232
  def forward(
233
  self,
234
  hidden_states: torch.Tensor,
235
+ alibi: torch.Tensor,
236
  attention_mask: torch.Tensor,
237
  layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
238
  head_mask: Optional[torch.Tensor] = None,
 
240
  output_attentions: bool = False,
241
  ):
242
  fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size]
243
+
244
  # 3 x [batch_size, seq_length, num_heads, head_dim]
245
  (query_layer, key_layer, value_layer) = self._split_heads(fused_qkv)
246
 
247
+ batch_size, q_length, _, _ = query_layer.shape
248
 
249
+ query_layer = query_layer.transpose(1, 2).reshape(batch_size * self.num_heads, q_length, self.head_dim)
250
  key_layer = key_layer.transpose(1, 2).reshape(
251
+ batch_size * self.num_kv,
252
+ q_length,
253
  self.head_dim,
254
  )
255
+ value_layer = value_layer.transpose(1, 2).reshape(batch_size * self.num_kv, q_length, self.head_dim)
256
 
257
+ query_layer, key_layer = self.maybe_rotary(query_layer, key_layer)
 
258
 
259
  if layer_past is not None:
260
  past_key, past_value = layer_past
261
  # concatenate along seq_length dimension:
262
+ # - key: [batch_size * self.num_heads, head_dim, kv_length]
263
  # - value: [batch_size * self.num_heads, kv_length, head_dim]
264
  key_layer = torch.cat((past_key, key_layer), dim=1)
265
  value_layer = torch.cat((past_value, value_layer), dim=1)
266
 
267
  _, kv_length, _ = key_layer.shape
268
+
269
+ if use_cache is True:
270
  present = (key_layer, value_layer)
271
  else:
272
  present = None
273
 
 
 
 
 
 
 
274
  if alibi is None:
275
+ query_layer_ = query_layer.reshape(batch_size, self.num_heads, -1, self.head_dim)
276
+ key_layer_ = key_layer.reshape(batch_size, self.num_kv, -1, self.head_dim)
277
+ value_layer_ = value_layer.reshape(batch_size, self.num_kv, -1, self.head_dim)
 
 
278
 
279
+ attn_output = F.scaled_dot_product_attention(
280
+ query_layer_, key_layer_, value_layer_, None, 0.0, is_causal=True
281
+ )
 
 
 
 
 
 
282
 
283
+ x = attn_output.view(batch_size, self.num_heads, q_length, self.head_dim)
284
+ x = x.permute(0, 2, 1, 3)
285
+ attn_output = x.reshape(batch_size, q_length, self.num_heads * self.head_dim)
286
 
287
  output_tensor = self.dense(attn_output)
288
 
289
+ outputs = (output_tensor, present)
290
+ assert not output_attentions # not supported.
291
+ return outputs
 
 
292
  else:
293
+ attention_mask_float = (attention_mask * 1.0).masked_fill(attention_mask, -1e9).to(torch.bfloat16)
294
+ matmul_result = query_layer @ key_layer.transpose(-1, -2)
295
 
296
  # change view to [batch_size, num_heads, q_length, kv_length]
297
+ attention_scores = matmul_result.view(batch_size, self.num_heads, q_length, kv_length)
298
 
299
  # cast attention scores to fp32, compute scaled softmax and cast back to initial dtype - [batch_size, num_heads, q_length, kv_length]
300
  input_dtype = attention_scores.dtype
301
  # `float16` has a minimum value of -65504.0, whereas `bfloat16` and `float32` have a minimum value of `-3.4e+38`
302
  if input_dtype == torch.float16 or input_dtype == torch.bfloat16:
303
  attention_scores = attention_scores.to(torch.float32)
304
+ # attn_weights = torch.masked_fill(attention_scores, attention_mask, torch.finfo(attention_scores.dtype).min)
305
+ attention_probs = F.softmax(
306
+ (attention_scores + alibi.view(batch_size, self.num_heads, 1, -1)) * self.inv_norm_factor + attention_mask_float,
307
+ dim=-1,
308
+ dtype=hidden_states.dtype,
309
+ )
 
310
  # [batch_size, num_heads, q_length, kv_length]
311
  attention_probs = self.attention_dropout(attention_probs)
312
 
313
  if head_mask is not None:
314
  attention_probs = attention_probs * head_mask
315
 
316
+ # change view [batch_size x num_heads, q_length, kv_length]
317
+ attention_probs_reshaped = attention_probs.view(batch_size * self.num_heads, q_length, kv_length)
318
 
319
  # matmul: [batch_size * num_heads, q_length, head_dim]
320
+ context_layer = attention_probs_reshaped @ value_layer
321
 
322
  # change view [batch_size, num_heads, q_length, head_dim]
323
  context_layer = self._merge_heads(context_layer)
324
 
325
  output_tensor = self.dense(context_layer)
326
 
327
+ outputs = (output_tensor, present)
328
  if output_attentions:
329
+ outputs += (attention_probs,)
330
+
331
+ return outputs
332
 
333
 
334
+ class MLP(nn.Module):
335
+ def __init__(self, config: RWConfig):
336
  super().__init__()
337
  hidden_size = config.hidden_size
338
 
339
+ self.dense_h_to_4h = Linear(hidden_size, 4 * hidden_size, bias=config.bias)
340
  self.act = nn.GELU()
341
+ self.dense_4h_to_h = Linear(4 * hidden_size, hidden_size, bias=config.bias)
342
  self.hidden_dropout = config.hidden_dropout
343
 
344
  def forward(self, x: torch.Tensor) -> torch.Tensor:
 
347
  return x
348
 
349
 
350
+ class DecoderLayer(nn.Module):
351
+ def __init__(self, config: RWConfig):
352
  super().__init__()
353
  hidden_size = config.hidden_size
354
+
355
+ self.input_layernorm = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
356
+ self.num_heads = config.n_head
357
+ self.self_attention = Attention(config)
358
+
359
+ if not config.parallel_attn:
360
+ # unused if parallel attn
361
+ self.post_attention_layernorm = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
362
+
363
+ self.mlp = MLP(config)
364
+
365
+ self.apply_residual_connection_post_layernorm = config.apply_residual_connection_post_layernorm
366
  self.hidden_dropout = config.hidden_dropout
 
367
 
368
+ self.config = config
 
 
 
 
 
 
 
 
369
 
370
  def forward(
371
  self,
372
  hidden_states: torch.Tensor,
373
+ alibi: torch.Tensor,
374
  attention_mask: torch.Tensor,
375
  layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
376
  head_mask: Optional[torch.Tensor] = None,
377
  use_cache: bool = False,
378
  output_attentions: bool = False,
379
  ):
 
380
 
381
+ layernorm_output = self.input_layernorm(hidden_states)
382
+ residual = hidden_states
 
 
 
383
 
384
  # Self attention.
385
  attn_outputs = self.self_attention(
386
+ layernorm_output,
387
  layer_past=layer_past,
388
  attention_mask=attention_mask,
389
  alibi=alibi,
 
394
 
395
  attention_output = attn_outputs[0]
396
 
397
+ if not self.config.parallel_attn:
398
+ residual = dropout_add(attention_output, residual, self.config.attention_dropout, training=self.training)
399
+ layernorm_output = self.post_attention_layernorm(residual)
 
 
 
 
 
400
 
401
  outputs = attn_outputs[1:]
402
 
403
  # MLP.
404
+ mlp_output = self.mlp(layernorm_output)
405
 
406
+ if self.config.parallel_attn:
407
  mlp_output += attention_output
408
 
409
  output = dropout_add(mlp_output, residual, self.config.hidden_dropout, training=self.training)
 
416
  return outputs # hidden_states, present, attentions
417
 
418
 
419
+ class RWPreTrainedModel(PreTrainedModel):
+ _keys_to_ignore_on_load_missing = [r"h.*.self_attention.scale_mask_softmax.causal_mask", r"lm_head.weight"]
  """
  An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
  models.
  """

+ config_class = RWConfig
  base_model_prefix = "transformer"
  supports_gradient_checkpointing = True
+ _no_split_modules = ["DecoderLayer"]

  def __init__(self, *inputs, **kwargs):
  super().__init__(*inputs, **kwargs)

  def _init_weights(self, module: nn.Module):
  """Initialize the weights."""
+ if isinstance(module, nn.Linear) or isinstance(module, Linear):
  # Slightly different from the TF version which uses truncated_normal for initialization
  # cf https://github.com/pytorch/pytorch/pull/5617
  module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)

  module.bias.data.zero_()
  module.weight.data.fill_(1.0)

  def _set_gradient_checkpointing(self, module: nn.Module, value: bool = False):
+ if isinstance(module, RWModel):
  module.gradient_checkpointing = value

  @staticmethod
+ def _convert_to_standard_cache(
  past_key_value: Tuple[Tuple[torch.Tensor, torch.Tensor]], batch_size: int
  ) -> Tuple[Tuple[torch.Tensor, torch.Tensor]]:
  """
  Standardizes the format of the cache so as to match most implementations, i.e. to tuple(tuple([batch_size,
  num_heads, ...]))
  """
+ batch_size_times_num_heads, head_dim, seq_length = past_key_value[0][0].shape
  num_heads = batch_size_times_num_heads // batch_size
+ # key: [batch_size * num_heads, head_dim, seq_length] -> [batch_size, num_heads, head_dim, seq_length]
+ # value: [batch_size * num_heads, seq_length, head_dim] -> [batch_size, num_heads, seq_length, head_dim]
  return tuple(
  (
+ layer_past[0].view(batch_size, num_heads, head_dim, seq_length),
+ layer_past[1].view(batch_size, num_heads, seq_length, head_dim),
  )
  for layer_past in past_key_value
  )

  def _convert_to_rw_cache(
  past_key_value: Tuple[Tuple[torch.Tensor, torch.Tensor]]
  ) -> Tuple[Tuple[torch.Tensor, torch.Tensor]]:
+ batch_size, num_heads, head_dim, seq_length = past_key_value[0][0].shape
  batch_size_times_num_heads = batch_size * num_heads
+ # key: [batch_size, num_heads, head_dim, seq_length] -> [batch_size * num_heads, head_dim, seq_length]
+ # value: [batch_size, num_heads, seq_length, head_dim] -> [batch_size * num_heads, seq_length, head_dim]
  return tuple(
  (
+ layer_past[0].view(batch_size_times_num_heads, head_dim, seq_length),
+ layer_past[1].view(batch_size_times_num_heads, seq_length, head_dim),
  )
  for layer_past in past_key_value
  )

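The two cache helpers above are pure reshapes between the fused `[batch_size * num_heads, ...]` layout the attention uses internally and the per-head layout the rest of the library expects. A minimal round-trip sketch on dummy tensors (toy sizes, not taken from the diff):

import torch

batch_size, num_heads, head_dim, seq_length = 2, 4, 8, 5

# internal ("rw") layout
key = torch.randn(batch_size * num_heads, head_dim, seq_length)
value = torch.randn(batch_size * num_heads, seq_length, head_dim)

# to the standard per-head layout
key_std = key.view(batch_size, num_heads, head_dim, seq_length)
value_std = value.view(batch_size, num_heads, seq_length, head_dim)

# and back; views share storage, so no data is copied
assert key_std.view(batch_size * num_heads, head_dim, seq_length).data_ptr() == key.data_ptr()
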
+ class RWModel(RWPreTrainedModel):
+ def __init__(self, config: RWConfig):
  super().__init__(config)

  self.embed_dim = config.hidden_size
+ self.num_heads = config.n_head
+ self.alibi = config.alibi

  # Embedding + LN Embedding
  self.word_embeddings = nn.Embedding(config.vocab_size, self.embed_dim)

  # Transformer blocks
+ self.h = nn.ModuleList([DecoderLayer(config) for _ in range(config.num_hidden_layers)])

  # Final Layer Norm
  self.ln_f = LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)

  def get_input_embeddings(self):
  return self.word_embeddings

  def _prepare_attn_mask(
+ self, attention_mask: torch.Tensor, input_shape: Tuple[int, int], past_key_values_length: int
  ) -> torch.BoolTensor:
+ # create causal mask
+ # [batch_size, seq_length] -> [batch_size, 1, tgt_length, src_length]
  combined_attention_mask = None
  device = attention_mask.device
+ _, src_length = input_shape

+ if src_length > 1:
  combined_attention_mask = _make_causal_mask(
  input_shape, device=device, past_key_values_length=past_key_values_length
  )

+ # [batch_size, seq_length] -> [batch_size, 1, tgt_length, src_length]
+ expanded_attn_mask = _expand_mask(attention_mask, tgt_length=src_length)
  combined_attention_mask = (
  expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask | combined_attention_mask
  )

  def set_input_embeddings(self, new_embeddings: torch.Tensor):
  self.word_embeddings = new_embeddings

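In `_prepare_attn_mask`, both pieces are boolean masks where True marks key positions a query must not attend to, which is why they are combined with `|`. A rough standalone sketch of the same combination on toy shapes (it does not call the module's `_make_causal_mask`/`_expand_mask` helpers):

import torch

seq_length = 5

# causal part: True above the diagonal = future positions that must not be attended to
causal = torch.ones(seq_length, seq_length).triu(1).bool()

# padding part: a [batch, seq] mask with 1 = real token, 0 = padding
attention_mask = torch.tensor([[1, 1, 1, 0, 0]])
padding = ~attention_mask.bool()[:, None, None, :]   # True = padded key position

# broadcast-OR them: True anywhere means "blocked", matching the `|` above
combined = causal[None, None, :, :] | padding
print(combined[0, 0])
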
  def forward(
  self,
  input_ids: Optional[torch.LongTensor] = None,

  output_attentions: Optional[bool] = None,
  output_hidden_states: Optional[bool] = None,
  return_dict: Optional[bool] = None,
+ **deprecated_arguments,
  ) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]:
+ if deprecated_arguments.pop("position_ids", False) is not False:
+ # `position_ids` could have been `torch.Tensor` or `None`, so defaulting the pop to `False` lets us detect users explicitly passing `None`
+ warnings.warn(
+ "`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. You can safely ignore"
+ " passing `position_ids`.",
+ FutureWarning,
+ )
+ if len(deprecated_arguments) > 0:
+ raise ValueError(f"Got unexpected arguments: {deprecated_arguments}")
+
  output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  output_hidden_states = (
  output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states

  if past_key_values is None:
  past_key_values = tuple([None] * len(self.h))

  # Prepare head mask if needed
  # 1.0 in head_mask indicates we keep the head
  # attention_probs has shape batch_size x num_heads x N x N
  # head_mask has shape n_layer x batch x num_heads x N x N
+ head_mask = self.get_head_mask(head_mask, self.config.n_layer)

  if inputs_embeds is None:
  inputs_embeds = self.word_embeddings(input_ids)

  all_hidden_states = () if output_hidden_states else None

  # Compute alibi tensor: check build_alibi_tensor documentation
+ seq_length_with_past = seq_length
  past_key_values_length = 0
  if past_key_values[0] is not None:
+ past_key_values_length = past_key_values[0][0].shape[2]
+ seq_length_with_past = seq_length_with_past + past_key_values_length
  if attention_mask is None:
+ attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device)
  else:
  attention_mask = attention_mask.to(hidden_states.device)

+ if self.alibi:
  alibi = build_alibi_tensor(attention_mask, self.num_heads, dtype=hidden_states.dtype)
  else:
  alibi = None

  )

  for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
+
  if output_hidden_states:
  all_hidden_states = all_hidden_states + (hidden_states,)

  if self.gradient_checkpointing and self.training:
+
  if use_cache:
  logger.warning(
  "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."

  if output_hidden_states:
  all_hidden_states = all_hidden_states + (hidden_states,)

  if not return_dict:
  return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)

  )

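When `config.alibi` is set, `build_alibi_tensor` replaces rotary embeddings with per-head linear biases added to the attention scores. A rough sketch of the idea for a head count that is a power of two (the real helper also handles other head counts, batches, and padding via `attention_mask`; names and sizes here are illustrative):

import torch

num_heads, seq_length = 4, 6

# one slope per head: 2^(-8/num_heads), 2^(-16/num_heads), ...
base = 2.0 ** (-8.0 / num_heads)
slopes = torch.tensor([base ** (i + 1) for i in range(num_heads)])

# bias grows linearly with key position, one row per head
positions = torch.arange(seq_length)
alibi = slopes[:, None, None] * positions[None, None, :]   # [num_heads, 1, seq_length]

# added to the attention scores before softmax; since softmax is shift-invariant
# per query, this effectively favours nearby keys over distant ones
print(alibi.shape)
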
+ class RWForCausalLM(RWPreTrainedModel):
+ _keys_to_ignore_on_load_missing = [r"h.*.self_attention.scale_mask_softmax.causal_mask", r"lm_head.weight"]

+ def __init__(self, config: RWConfig):
  super().__init__(config)
+ self.transformer = RWModel(config)
  self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

  # Initialize weights and apply final processing

  def prepare_inputs_for_generation(
  self,
  input_ids: torch.LongTensor,
+ past: Optional[torch.Tensor] = None,
  attention_mask: Optional[torch.Tensor] = None,
  **kwargs,
  ) -> dict:
+ # only keep the last token of input_ids if a past is provided
+ if past:
+ input_ids = input_ids[:, -1].unsqueeze(-1)
+
+ # the cache may be in the standard format (e.g. in contrastive search), convert to our format if needed
+ if past[0][0].shape[0] == input_ids.shape[0]:
+ past = self._convert_to_rw_cache(past)

  return {
  "input_ids": input_ids,
+ "past_key_values": past,
  "use_cache": kwargs.get("use_cache"),
  "attention_mask": attention_mask,
  }

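The `past[0][0].shape[0] == input_ids.shape[0]` check distinguishes the standard per-head cache (leading dimension `batch_size`) from the model's fused layout (leading dimension `batch_size * num_heads`); it only disambiguates when `num_heads > 1`. A toy illustration with made-up sizes:

import torch

batch_size, num_heads, head_dim, seq_len = 2, 4, 8, 3

standard_key = torch.randn(batch_size, num_heads, head_dim, seq_len)   # e.g. produced by contrastive search
fused_key = torch.randn(batch_size * num_heads, head_dim, seq_len)     # the model's own cache layout

input_ids = torch.ones(batch_size, 1, dtype=torch.long)
print(standard_key.shape[0] == input_ids.shape[0])  # True  -> needs _convert_to_rw_cache
print(fused_key.shape[0] == input_ids.shape[0])     # False -> already in the fused layout
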
  def forward(
  self,
  input_ids: Optional[torch.LongTensor] = None,

  output_attentions: Optional[bool] = None,
  output_hidden_states: Optional[bool] = None,
  return_dict: Optional[bool] = None,
+ **deprecated_arguments,
  ) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]:
  r"""
  labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):

  `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
  are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
  """
+ if deprecated_arguments.pop("position_ids", False) is not False:
+ # `position_ids` could have been `torch.Tensor` or `None`, so defaulting the pop to `False` lets us detect users explicitly passing `None`
+ warnings.warn(
+ "`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. You can safely ignore"
+ " passing `position_ids`.",
+ FutureWarning,
+ )
+ if len(deprecated_arguments) > 0:
+ raise ValueError(f"Got unexpected arguments: {deprecated_arguments}")

  return_dict = return_dict if return_dict is not None else self.config.use_return_dict

  This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
  [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
  beam_idx at every generation step.
+
  Output shares the same memory storage as `past`.
  """
+ standardized_past = self._convert_to_standard_cache(past, batch_size=len(beam_idx))

  # Get a copy of `beam_idx` on all the devices where we need those indices.
  device_to_beam_idx = {

  layer_past[0].index_select(0, device_to_beam_idx[layer_past[0].device]),
  layer_past[1].index_select(0, device_to_beam_idx[layer_past[0].device]),
  )
+ for layer_past in standardized_past
  )
+ return self._convert_to_rw_cache(reordered_past)

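`_reorder_cache` first moves the cache to the standard layout so the leading dimension is the batch (one entry per beam), gathers the surviving beams with `index_select`, then converts back to the fused layout. A small sketch of the gather step on a toy tensor (single device, made-up sizes):

import torch

num_beams, num_heads, head_dim, seq_len = 3, 2, 4, 5
key = torch.randn(num_beams, num_heads, head_dim, seq_len)   # standard layout: beams first

# e.g. beam 2 was duplicated and beam 1 dropped during beam search
beam_idx = torch.tensor([2, 2, 0])
reordered_key = key.index_select(0, beam_idx)
print(reordered_key.shape)   # torch.Size([3, 2, 4, 5])
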
+ class RWForSequenceClassification(RWPreTrainedModel):
+ _keys_to_ignore_on_load_missing = [r"h.*.self_attention.scale_mask_softmax.causal_mask", r"lm_head.weight"]
+
+ def __init__(self, config: RWConfig):
  super().__init__(config)
  self.num_labels = config.num_labels
+ self.transformer = RWModel(config)
  self.score = nn.Linear(config.hidden_size, config.num_labels, bias=False)

  # Initialize weights and apply final processing
  self.post_init()

  def forward(
  self,
  input_ids: Optional[torch.LongTensor] = None,

  output_attentions: Optional[bool] = None,
  output_hidden_states: Optional[bool] = None,
  return_dict: Optional[bool] = None,
+ **deprecated_arguments,
  ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutputWithPast]:
  r"""
  labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):

  config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
  `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  """
+ if deprecated_arguments.pop("position_ids", False) is not False:
+ # `position_ids` could have been `torch.Tensor` or `None`, so defaulting the pop to `False` lets us detect users explicitly passing `None`
+ warnings.warn(
+ "`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. You can safely ignore"
+ " passing `position_ids`.",
+ FutureWarning,
+ )
+ if len(deprecated_arguments) > 0:
+ raise ValueError(f"Got unexpected arguments: {deprecated_arguments}")

  return_dict = return_dict if return_dict is not None else self.config.use_return_dict

  )

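The docstring's regression-vs-classification distinction follows the usual Transformers convention: one label means mean-squared-error regression, more than one means cross-entropy classification. A generic sketch of that convention (not the collapsed loss code itself; sizes are illustrative):

import torch
from torch.nn import CrossEntropyLoss, MSELoss

num_labels = 1                      # try 3 for the classification case
pooled_logits = torch.randn(2, num_labels)
labels = torch.randn(2) if num_labels == 1 else torch.randint(0, num_labels, (2,))

if num_labels == 1:
    loss = MSELoss()(pooled_logits.squeeze(), labels.squeeze())                       # regression
else:
    loss = CrossEntropyLoss()(pooled_logits.view(-1, num_labels), labels.view(-1))    # classification
print(loss)
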
+ class RWForTokenClassification(RWPreTrainedModel):
+ _keys_to_ignore_on_load_missing = [r"h.*.self_attention.scale_mask_softmax.causal_mask", r"lm_head.weight"]
+
+ def __init__(self, config: RWConfig):
  super().__init__(config)
  self.num_labels = config.num_labels

+ self.transformer = RWModel(config)
+ if hasattr(config, "classifier_dropout") and config.classifier_dropout is not None:
  classifier_dropout = config.classifier_dropout
+ elif hasattr(config, "hidden_dropout") and config.hidden_dropout is not None:
  classifier_dropout = config.hidden_dropout
  else:
  classifier_dropout = 0.1

  # Initialize weights and apply final processing
  self.post_init()

  def forward(
  self,
  input_ids: Optional[torch.LongTensor] = None,

  output_attentions: Optional[bool] = None,
  output_hidden_states: Optional[bool] = None,
  return_dict: Optional[bool] = None,
+ **deprecated_arguments,
  ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
  r"""
  labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):

  config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
  `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  """
+ if deprecated_arguments.pop("position_ids", False) is not False:
+ # `position_ids` could have been `torch.Tensor` or `None`, so defaulting the pop to `False` lets us detect users explicitly passing `None`
+ warnings.warn(
+ "`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. You can safely ignore"
+ " passing `position_ids`.",
+ FutureWarning,
+ )
+ if len(deprecated_arguments) > 0:
+ raise ValueError(f"Got unexpected arguments: {deprecated_arguments}")

  return_dict = return_dict if return_dict is not None else self.config.use_return_dict

  if labels is not None:
  batch_size, seq_length = labels.shape
  loss_fct = CrossEntropyLoss()
+ loss = loss_fct(logits.view(batch_size * seq_length, self.num_labels), labels.view(batch_size * seq_length))

  if not return_dict:
  output = (logits,) + transformer_outputs[2:]

  )

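The token-classification loss above flattens the `[batch_size, seq_length, num_labels]` logits so `CrossEntropyLoss` sees one row per token; labels set to `-100` are ignored by the default `ignore_index`. A toy sketch with made-up sizes:

import torch
from torch.nn import CrossEntropyLoss

batch_size, seq_length, num_labels = 2, 4, 3
logits = torch.randn(batch_size, seq_length, num_labels)
labels = torch.tensor([[0, 2, 1, -100],
                       [1, 1, -100, -100]])   # -100 marks padding / ignored tokens

loss_fct = CrossEntropyLoss()   # ignore_index defaults to -100
loss = loss_fct(logits.view(batch_size * seq_length, num_labels),
                labels.view(batch_size * seq_length))
print(loss)
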
+ class RWForQuestionAnswering(RWPreTrainedModel):
+ _keys_to_ignore_on_load_missing = [r"h.*.self_attention.scale_mask_softmax.causal_mask", r"lm_head.weight"]
+
  def __init__(self, config):
  super().__init__(config)
+ self.transformer = RWModel(config)
  self.qa_outputs = nn.Linear(config.hidden_size, 2)

  # Initialize weights and apply final processing
  self.post_init()

  def forward(
  self,
  input_ids: Optional[torch.LongTensor] = None,
  attention_mask: Optional[torch.FloatTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
  head_mask: Optional[torch.FloatTensor] = None,
  inputs_embeds: Optional[torch.FloatTensor] = None,
  start_positions: Optional[torch.LongTensor] = None,

  outputs = self.transformer(
  input_ids,
  attention_mask=attention_mask,
+ position_ids=position_ids,
  head_mask=head_mask,
  inputs_embeds=inputs_embeds,
  output_attentions=output_attentions,