ClaudiaIoana550 commited on
Commit
973ff97
1 Parent(s): c70e30c

Rename modeling_falcon.py to modelling_RW.py

Browse files
Files changed (1) hide show
  1. modeling_falcon.py → modelling_RW.py +245 -407
modeling_falcon.py → modelling_RW.py RENAMED
@@ -1,20 +1,9 @@
1
- # coding=utf-8
2
- # Copyright 2023 the Falcon authors and HuggingFace Inc. team. All rights reserved.
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
- """PyTorch Falcon model."""
16
 
17
  import math
 
18
  from typing import Optional, Tuple, Union
19
 
20
  import torch
@@ -31,60 +20,59 @@ from transformers.modeling_outputs import (
31
  TokenClassifierOutput,
32
  )
33
  from transformers.modeling_utils import PreTrainedModel
34
- from transformers.utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
35
- from .configuration_falcon import FalconConfig
36
-
37
 
38
  logger = logging.get_logger(__name__)
39
 
40
- FALCON_PRETRAINED_MODEL_ARCHIVE_LIST = [
41
- "tiiuae/falcon-40b",
42
- "tiiuae/falcon-40b-instruct",
43
- "tiiuae/falcon-7b",
44
- "tiiuae/falcon-7b-instruct",
45
- "tiiuae/falcon-rw-7b",
46
- "tiiuae/falcon-rw-1b",
47
- ]
48
- _CHECKPOINT_FOR_DOC = "Rocketknight1/falcon-rw-1b"
49
- _CONFIG_FOR_DOC = "FalconConfig"
50
-
51
-
52
  # NOTE(Hesslow): Unfortunately we did not fuse matmul and bias during training, this means that there's one additional quantization to bfloat16 between the operations.
53
  # In order not to degrade the quality of our HF-port, we keep these characteristics in the final model.
54
- class FalconLinear(nn.Linear):
55
  def forward(self, input: torch.Tensor) -> torch.Tensor:
56
- hidden_states = input @ self.weight.T
57
  if self.bias is None:
58
- return hidden_states
59
- return hidden_states + self.bias
 
 
60
 
 
61
 
62
  # rotary pos emb helpers (torch.jit.script does not seem to support staticmethod...)
63
  def rotate_half(x):
64
  x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
65
- return torch.cat((-x2, x1), dim=-1)
66
 
67
 
68
- class FalconRotaryEmbedding(nn.Module):
69
  """Implementation of RotaryEmbedding from GPT-NeoX.
70
- This implementation is designed to operate on queries and keys that are compatible with `[batch_size,
71
- n_heads_per_partition, seq_len, head_dim]` (e.g. MinGPTAttention format).
72
  """
73
 
74
- def __init__(self, head_dim: int, base=10000):
 
 
 
 
75
  super().__init__()
76
  inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
77
  self.register_buffer("inv_freq", inv_freq, persistent=False)
78
  self.head_dim = head_dim
79
- self.seq_len_cached = -1
 
80
  self.cos_cached: torch.Tensor | None = None
81
  self.sin_cached: torch.Tensor | None = None
82
 
83
- def cos_sin(self, seq_len: int, past_key_values_length: int, device="cpu", dtype=torch.bfloat16) -> torch.Tensor:
84
- total_length = seq_len + past_key_values_length
85
- if total_length > self.seq_len_cached:
86
- self.seq_len_cached = total_length
87
- t = torch.arange(total_length, device=device, dtype=self.inv_freq.dtype)
 
 
 
 
88
  freqs = torch.einsum("i,j->ij", t, self.inv_freq)
89
  emb = torch.cat((freqs, freqs), dim=-1).to(device)
90
 
@@ -97,46 +85,36 @@ class FalconRotaryEmbedding(nn.Module):
97
  self.cos_cached = self.cos_cached.type(dtype)
98
  self.sin_cached = self.sin_cached.type(dtype)
99
 
100
- return (
101
- self.cos_cached[:, past_key_values_length : seq_len + past_key_values_length],
102
- self.sin_cached[:, past_key_values_length : seq_len + past_key_values_length],
103
- )
104
 
105
- def forward(self, query, key, past_key_values_length=0):
106
- batch, seq_len, head_dim = query.shape
107
- cos, sin = self.cos_sin(seq_len, past_key_values_length, query.device, query.dtype)
108
- return (query * cos) + (rotate_half(query) * sin), (key * cos) + (rotate_half(key) * sin)
109
 
110
 
111
  def _make_causal_mask(
112
  input_ids_shape: torch.Size, device: torch.device, past_key_values_length: int
113
  ) -> torch.BoolTensor:
114
- """
115
- Make causal mask used for self-attention. This mask does not take the existing attention mask into account - it
116
- just blocks tokens from attending forwards in the sequence. The output shape will be `[batch_size, 1,
117
- target_length, target_length+past_key_values_length]`.
118
- """
119
  batch_size, target_length = input_ids_shape
 
 
 
 
 
 
 
120
 
121
- mask = torch.triu(torch.ones((target_length, target_length), dtype=torch.bool, device=device), diagonal=1)
122
- # If past_key_values_length is 0 this is an empty tensor and the concatenation is a no-op.
123
- # This code style is an unfortunate consequence of getting your TF engineer to port models; doing it this
124
- # way avoids a data-dependent conditional, which will help me when I have to port this to XLA later.
125
- past_mask = torch.zeros((target_length, past_key_values_length), dtype=torch.bool, device=device)
126
- mask = torch.cat([past_mask, mask], dim=-1)
127
  expanded_mask = mask[None, None, :, :].expand(batch_size, 1, target_length, target_length + past_key_values_length)
128
  return expanded_mask
129
 
130
 
131
- def _expand_mask(mask: torch.Tensor, past_key_values_length: int) -> torch.BoolTensor:
132
- """
133
- Expands attention_mask from `[batch_size, seq_length]` to `[batch_size, 1, seq_length, seq_length + past_length]`.
134
- """
135
- batch_size, total_length = mask.shape
136
- seq_length = total_length - past_key_values_length if past_key_values_length is not None else total_length
137
 
138
  expanded_mask = ~(mask[:, None, None, :].to(torch.bool))
139
- return expanded_mask.expand(batch_size, 1, seq_length, total_length)
140
 
141
 
142
  def build_alibi_tensor(attention_mask: torch.Tensor, num_heads: int, dtype: torch.dtype) -> torch.Tensor:
@@ -167,32 +145,18 @@ def build_alibi_tensor(attention_mask: torch.Tensor, num_heads: int, dtype: torc
167
  return alibi.reshape(batch_size * num_heads, 1, seq_length).to(dtype)
168
 
169
 
170
- # Copied from transformers.models.bloom.modeling_bloom.dropout_add
171
  def dropout_add(x: torch.Tensor, residual: torch.Tensor, prob: float, training: bool) -> torch.Tensor:
172
- """
173
- Dropout add function
174
-
175
- Args:
176
- x (`torch.tensor`, *required*):
177
- input tensor
178
- residual (`torch.tensor`, *required*):
179
- residual tensor
180
- prob (`float`, *required*):
181
- dropout probability
182
- training (`bool`, *required*):
183
- training mode
184
- """
185
  out = F.dropout(x, p=prob, training=training)
186
  out = residual + out
187
  return out
188
 
189
 
190
- class FalconAttention(nn.Module):
191
- def __init__(self, config: FalconConfig):
192
  super().__init__()
193
 
194
  self.hidden_size = config.hidden_size
195
- self.num_heads = config.num_attention_heads
196
  self.head_dim = self.hidden_size // self.num_heads
197
  self.split_size = self.hidden_size
198
  self.hidden_dropout = config.hidden_dropout
@@ -203,27 +167,26 @@ class FalconAttention(nn.Module):
203
  f" {self.num_heads})."
204
  )
205
 
206
- self.maybe_rotary = FalconRotaryEmbedding(config.head_dim) if config.rotary else lambda q, k, t: (q, k)
207
 
208
  # Layer-wise attention scaling
209
  self.inv_norm_factor = 1.0 / math.sqrt(self.head_dim)
210
  self.beta = self.inv_norm_factor
211
- if config.new_decoder_architecture:
212
- qkv_out_dim = (config.num_kv_heads * 2 + config.num_attention_heads) * self.head_dim
213
- elif config.multi_query:
214
- qkv_out_dim = self.hidden_size + 2 * self.head_dim
215
- else:
216
- qkv_out_dim = 3 * self.hidden_size
217
- self.query_key_value = FalconLinear(self.hidden_size, qkv_out_dim, bias=config.bias)
218
- self.new_decoder_architecture = config.new_decoder_architecture
219
  self.multi_query = config.multi_query
220
- self.dense = FalconLinear(self.hidden_size, self.hidden_size, bias=config.bias)
221
  self.attention_dropout = nn.Dropout(config.attention_dropout)
222
- self.num_kv_heads = config.num_kv_heads if (self.new_decoder_architecture or not self.multi_query) else 1
223
 
224
  def _split_heads(self, fused_qkv: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
225
  """
226
- Split the last dimension into (num_heads, head_dim), results share same memory storage as `fused_qkv`
 
227
 
228
  Args:
229
  fused_qkv (`torch.tensor`, *required*): [batch_size, seq_length, num_heads * 3 * head_dim]
@@ -232,18 +195,7 @@ class FalconAttention(nn.Module):
232
  query: [batch_size, seq_length, num_heads, head_dim] key: [batch_size, seq_length, num_heads, head_dim]
233
  value: [batch_size, seq_length, num_heads, head_dim]
234
  """
235
- if self.new_decoder_architecture:
236
- batch, seq_len, _ = fused_qkv.shape
237
- qkv = fused_qkv.view(batch, seq_len, -1, self.num_heads // self.num_kv_heads + 2, self.head_dim)
238
- query = qkv[:, :, :, :-2]
239
- key = qkv[:, :, :, [-2]]
240
- value = qkv[:, :, :, [-1]]
241
- key = torch.broadcast_to(key, query.shape)
242
- value = torch.broadcast_to(value, query.shape)
243
-
244
- query, key, value = [x.flatten(2, 3) for x in (query, key, value)]
245
- return query, key, value
246
- elif not self.multi_query:
247
  batch_size, seq_length, three_times_hidden_size = fused_qkv.shape
248
  fused_qkv = fused_qkv.view(batch_size, seq_length, self.num_heads, 3, self.head_dim)
249
  return fused_qkv[..., 0, :], fused_qkv[..., 1, :], fused_qkv[..., 2, :]
@@ -252,13 +204,12 @@ class FalconAttention(nn.Module):
252
  fused_qkv = fused_qkv.view(batch_size, seq_length, self.num_heads + 2, self.head_dim)
253
  return fused_qkv[..., :-2, :], fused_qkv[..., [-2], :], fused_qkv[..., [-1], :]
254
 
255
- # Copied from transformers.models.bloom.modeling_bloom.BloomAttention._merge_heads
256
  def _merge_heads(self, x: torch.Tensor) -> torch.Tensor:
257
  """
258
  Merge heads together over the last dimenstion
259
 
260
  Args:
261
- x (`torch.tensor`, *required*): [batch_size * num_heads, seq_length, head_dim]
262
 
263
  Returns:
264
  torch.tensor: [batch_size, seq_length, num_heads * head_dim]
@@ -281,7 +232,7 @@ class FalconAttention(nn.Module):
281
  def forward(
282
  self,
283
  hidden_states: torch.Tensor,
284
- alibi: Optional[torch.Tensor],
285
  attention_mask: torch.Tensor,
286
  layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
287
  head_mask: Optional[torch.Tensor] = None,
@@ -289,120 +240,105 @@ class FalconAttention(nn.Module):
289
  output_attentions: bool = False,
290
  ):
291
  fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size]
292
- num_kv_heads = self.num_heads if self.new_decoder_architecture else self.num_kv_heads
293
  # 3 x [batch_size, seq_length, num_heads, head_dim]
294
  (query_layer, key_layer, value_layer) = self._split_heads(fused_qkv)
295
 
296
- batch_size, query_length, _, _ = query_layer.shape
297
 
298
- query_layer = query_layer.transpose(1, 2).reshape(batch_size * self.num_heads, query_length, self.head_dim)
299
  key_layer = key_layer.transpose(1, 2).reshape(
300
- batch_size * num_kv_heads,
301
- query_length,
302
  self.head_dim,
303
  )
304
- value_layer = value_layer.transpose(1, 2).reshape(batch_size * num_kv_heads, query_length, self.head_dim)
305
 
306
- past_kv_length = 0 if layer_past is None else layer_past[0].shape[1]
307
- query_layer, key_layer = self.maybe_rotary(query_layer, key_layer, past_kv_length)
308
 
309
  if layer_past is not None:
310
  past_key, past_value = layer_past
311
  # concatenate along seq_length dimension:
312
- # - key: [batch_size * self.num_heads, kv_length, head_dim]
313
  # - value: [batch_size * self.num_heads, kv_length, head_dim]
314
  key_layer = torch.cat((past_key, key_layer), dim=1)
315
  value_layer = torch.cat((past_value, value_layer), dim=1)
316
 
317
  _, kv_length, _ = key_layer.shape
318
- if use_cache:
 
319
  present = (key_layer, value_layer)
320
  else:
321
  present = None
322
 
323
- attention_mask_float = (attention_mask * 1.0).masked_fill(attention_mask, float("-1e9")).to(query_layer.dtype)
324
-
325
- query_layer_ = query_layer.reshape(batch_size, self.num_heads, -1, self.head_dim)
326
- key_layer_ = key_layer.reshape(batch_size, num_kv_heads, -1, self.head_dim)
327
- value_layer_ = value_layer.reshape(batch_size, num_kv_heads, -1, self.head_dim)
328
-
329
  if alibi is None:
330
- if output_attentions:
331
- # F.scaled_dot_product_attention doesn't return the attention weights, so we have
332
- # to do it by hand if we want them
333
- attention_scores = query_layer_ @ key_layer_.transpose(-1, -2)
334
- attention_scores /= math.sqrt(self.head_dim)
335
 
336
- attention_scores = F.softmax(
337
- attention_scores + attention_mask_float, dim=-1, dtype=hidden_states.dtype
338
- )
339
- attn_output = attention_scores @ value_layer_
340
- else:
341
- attn_output = F.scaled_dot_product_attention(
342
- query_layer_, key_layer_, value_layer_, attention_mask_float, 0.0, is_causal=False
343
- )
344
- attention_scores = None
345
 
346
- attn_output = attn_output.view(batch_size, self.num_heads, query_length, self.head_dim)
347
- attn_output = attn_output.permute(0, 2, 1, 3)
348
- attn_output = attn_output.reshape(batch_size, query_length, self.num_heads * self.head_dim)
349
 
350
  output_tensor = self.dense(attn_output)
351
 
352
- if output_attentions:
353
- return output_tensor, present, attention_scores
354
- else:
355
- return output_tensor, present
356
-
357
  else:
358
- matmul_result = query_layer_ @ key_layer_.transpose(-1, -2)
 
359
 
360
  # change view to [batch_size, num_heads, q_length, kv_length]
361
- attention_scores = matmul_result.view(batch_size, self.num_heads, query_length, kv_length)
362
 
363
  # cast attention scores to fp32, compute scaled softmax and cast back to initial dtype - [batch_size, num_heads, q_length, kv_length]
364
  input_dtype = attention_scores.dtype
365
  # `float16` has a minimum value of -65504.0, whereas `bfloat16` and `float32` have a minimum value of `-3.4e+38`
366
  if input_dtype == torch.float16 or input_dtype == torch.bfloat16:
367
  attention_scores = attention_scores.to(torch.float32)
368
- # Matt (HF) note: We could possibly use F.scaled_dot_product_attention here too, by
369
- # adding (alibi * self.inv_norm_factor) to attention_mask_float. I think this would be mathematically
370
- # equivalent and more performant, but there might be a numerical difference. If you're reading this
371
- # and you'd like to experiment and maybe file a PR, feel free!
372
- attention_logits = attention_scores + alibi.view(batch_size, self.num_heads, 1, -1)
373
- attention_logits *= self.inv_norm_factor
374
- attention_probs = F.softmax(attention_logits + attention_mask_float, dim=-1, dtype=hidden_states.dtype)
375
  # [batch_size, num_heads, q_length, kv_length]
376
  attention_probs = self.attention_dropout(attention_probs)
377
 
378
  if head_mask is not None:
379
  attention_probs = attention_probs * head_mask
380
 
381
- # change view [batch_size, num_heads, q_length, kv_length]
382
- attention_probs_reshaped = attention_probs.view(batch_size, self.num_heads, query_length, kv_length)
383
 
384
  # matmul: [batch_size * num_heads, q_length, head_dim]
385
- context_layer = (attention_probs_reshaped @ value_layer_).flatten(0, 1)
386
 
387
  # change view [batch_size, num_heads, q_length, head_dim]
388
  context_layer = self._merge_heads(context_layer)
389
 
390
  output_tensor = self.dense(context_layer)
391
 
 
392
  if output_attentions:
393
- return output_tensor, present, attention_probs
394
- else:
395
- return output_tensor, present
396
 
397
 
398
- class FalconMLP(nn.Module):
399
- def __init__(self, config: FalconConfig):
400
  super().__init__()
401
  hidden_size = config.hidden_size
402
 
403
- self.dense_h_to_4h = FalconLinear(hidden_size, 4 * hidden_size, bias=config.bias)
404
  self.act = nn.GELU()
405
- self.dense_4h_to_h = FalconLinear(4 * hidden_size, hidden_size, bias=config.bias)
406
  self.hidden_dropout = config.hidden_dropout
407
 
408
  def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -411,47 +347,43 @@ class FalconMLP(nn.Module):
411
  return x
412
 
413
 
414
- class FalconDecoderLayer(nn.Module):
415
- def __init__(self, config: FalconConfig):
416
  super().__init__()
417
  hidden_size = config.hidden_size
418
- self.num_heads = config.num_attention_heads
419
- self.self_attention = FalconAttention(config)
420
- self.mlp = FalconMLP(config)
 
 
 
 
 
 
 
 
 
421
  self.hidden_dropout = config.hidden_dropout
422
- self.config = config
423
 
424
- if config.new_decoder_architecture:
425
- # The layer norm before self-attention
426
- self.ln_attn = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
427
- # The layer norm before the MLP
428
- self.ln_mlp = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
429
- else:
430
- self.input_layernorm = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
431
- if not config.parallel_attn:
432
- self.post_attention_layernorm = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
433
 
434
  def forward(
435
  self,
436
  hidden_states: torch.Tensor,
437
- alibi: Optional[torch.Tensor],
438
  attention_mask: torch.Tensor,
439
  layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
440
  head_mask: Optional[torch.Tensor] = None,
441
  use_cache: bool = False,
442
  output_attentions: bool = False,
443
  ):
444
- residual = hidden_states
445
 
446
- if self.config.new_decoder_architecture:
447
- attention_layernorm_out = self.ln_attn(hidden_states)
448
- mlp_layernorm_out = self.ln_mlp(hidden_states)
449
- else:
450
- attention_layernorm_out = self.input_layernorm(hidden_states)
451
 
452
  # Self attention.
453
  attn_outputs = self.self_attention(
454
- attention_layernorm_out,
455
  layer_past=layer_past,
456
  attention_mask=attention_mask,
457
  alibi=alibi,
@@ -462,21 +394,16 @@ class FalconDecoderLayer(nn.Module):
462
 
463
  attention_output = attn_outputs[0]
464
 
465
- if not self.config.new_decoder_architecture:
466
- if self.config.parallel_attn:
467
- mlp_layernorm_out = attention_layernorm_out
468
- else:
469
- residual = dropout_add(
470
- attention_output, residual, self.config.attention_dropout, training=self.training
471
- )
472
- mlp_layernorm_out = self.post_attention_layernorm(residual)
473
 
474
  outputs = attn_outputs[1:]
475
 
476
  # MLP.
477
- mlp_output = self.mlp(mlp_layernorm_out)
478
 
479
- if self.config.new_decoder_architecture or self.config.parallel_attn:
480
  mlp_output += attention_output
481
 
482
  output = dropout_add(mlp_output, residual, self.config.hidden_dropout, training=self.training)
@@ -489,93 +416,24 @@ class FalconDecoderLayer(nn.Module):
489
  return outputs # hidden_states, present, attentions
490
 
491
 
492
- FALCON_START_DOCSTRING = r"""
493
-
494
- This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
495
- library implements for all its model (such as downloading or saving, resizing the input embeddings etc.)
496
-
497
- This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
498
- Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
499
- and behavior.
500
-
501
- Parameters:
502
- config ([`FalconConfig`]): Model configuration class with all the parameters of the model.
503
- Initializing with a config file does not load the weights associated with the model, only the
504
- configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
505
- """
506
-
507
- FALCON_INPUTS_DOCSTRING = r"""
508
- Args:
509
- input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
510
- `input_ids_length` = `sequence_length` if `past_key_values` is `None` else `past_key_values[0][0].shape[2]`
511
- (`sequence_length` of input past key value states). Indices of input sequence tokens in the vocabulary.
512
-
513
- If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
514
- `input_ids`.
515
-
516
- Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
517
- [`PreTrainedTokenizer.__call__`] for details.
518
-
519
- [What are input IDs?](../glossary#input-ids)
520
- past_key_values (`Tuple[Tuple[torch.Tensor]]` of length `config.num_hidden_layers`):
521
- Contains precomputed hidden-states (key and values in the attention blocks) as computed by the model (see
522
- `past_key_values` output below). Can be used to speed up sequential decoding. The `input_ids` which have
523
- their past given to this model should not be passed as `input_ids` as they have already been computed.
524
-
525
- Each element of `past_key_values` is a tuple (past_key, past_value):
526
- - past_key: [batch_size * num_heads, head_dim, kv_length]
527
- - past_value: [batch_size * num_heads, kv_length, head_dim]
528
- attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
529
- Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
530
-
531
- - 1 for tokens that are **not masked**,
532
- - 0 for tokens that are **masked**.
533
-
534
- [What are attention masks?](../glossary#attention-mask)
535
- head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
536
- Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
537
-
538
- - 1 indicates the head is **not masked**,
539
- - 0 indicates the head is **masked**.
540
-
541
- inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
542
- Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
543
- is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
544
- model's internal embedding lookup matrix.
545
-
546
- If `past_key_values` is used, optionally only the last `inputs_embeds` have to be input (see
547
- `past_key_values`).
548
- use_cache (`bool`, *optional*):
549
- If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
550
- `past_key_values`).
551
- output_attentions (`bool`, *optional*):
552
- Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
553
- tensors for more detail.
554
- output_hidden_states (`bool`, *optional*):
555
- Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
556
- more detail.
557
- return_dict (`bool`, *optional*):
558
- Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple.
559
- """
560
-
561
-
562
- class FalconPreTrainedModel(PreTrainedModel):
563
  """
564
  An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
565
  models.
566
  """
567
 
568
- config_class = FalconConfig
569
  base_model_prefix = "transformer"
570
  supports_gradient_checkpointing = True
571
- _no_split_modules = ["FalconDecoderLayer"]
572
 
573
  def __init__(self, *inputs, **kwargs):
574
  super().__init__(*inputs, **kwargs)
575
 
576
  def _init_weights(self, module: nn.Module):
577
  """Initialize the weights."""
578
- if isinstance(module, nn.Linear) or isinstance(module, FalconLinear):
579
  # Slightly different from the TF version which uses truncated_normal for initialization
580
  # cf https://github.com/pytorch/pytorch/pull/5617
581
  module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
@@ -589,28 +447,26 @@ class FalconPreTrainedModel(PreTrainedModel):
589
  module.bias.data.zero_()
590
  module.weight.data.fill_(1.0)
591
 
592
- # Copied from transformers.models.bloom.modeling_bloom.BloomPreTrainedModel._set_gradient_checkpointing with BloomModel->FalconModel
593
  def _set_gradient_checkpointing(self, module: nn.Module, value: bool = False):
594
- if isinstance(module, FalconModel):
595
  module.gradient_checkpointing = value
596
 
597
  @staticmethod
598
- def _convert_cache_to_standard_format(
599
  past_key_value: Tuple[Tuple[torch.Tensor, torch.Tensor]], batch_size: int
600
  ) -> Tuple[Tuple[torch.Tensor, torch.Tensor]]:
601
  """
602
  Standardizes the format of the cache so as to match most implementations, i.e. to tuple(tuple([batch_size,
603
  num_heads, ...]))
604
  """
605
- batch_size_times_num_heads, kv_length, head_dim = past_key_value[0][0].shape
606
- # [batch_size * self.num_heads, kv_length, head_dim] -> [batch_size, num_heads, kv_length, head_dim]
607
- # Note that don't want to use self.num_attention_heads because the number of heads may vary depending
608
- # on whether we use multi_query attention.
609
  num_heads = batch_size_times_num_heads // batch_size
 
 
610
  return tuple(
611
  (
612
- layer_past[0].view(batch_size, num_heads, kv_length, head_dim),
613
- layer_past[1].view(batch_size, num_heads, kv_length, head_dim),
614
  )
615
  for layer_past in past_key_value
616
  )
@@ -619,35 +475,32 @@ class FalconPreTrainedModel(PreTrainedModel):
619
  def _convert_to_rw_cache(
620
  past_key_value: Tuple[Tuple[torch.Tensor, torch.Tensor]]
621
  ) -> Tuple[Tuple[torch.Tensor, torch.Tensor]]:
622
- batch_size, num_heads, kv_length, head_dim = past_key_value[0][0].shape
623
  batch_size_times_num_heads = batch_size * num_heads
624
- # [batch_size, num_heads, kv_length, head_dim] -> [batch_size * num_heads, kv_length, head_dim]
 
625
  return tuple(
626
  (
627
- layer_past[0].view(batch_size_times_num_heads, kv_length, head_dim),
628
- layer_past[1].view(batch_size_times_num_heads, kv_length, head_dim),
629
  )
630
  for layer_past in past_key_value
631
  )
632
 
633
 
634
- @add_start_docstrings(
635
- "The bare Falcon Model transformer outputting raw hidden-states without any specific head on top.",
636
- FALCON_START_DOCSTRING,
637
- )
638
- class FalconModel(FalconPreTrainedModel):
639
- def __init__(self, config: FalconConfig):
640
  super().__init__(config)
641
 
642
  self.embed_dim = config.hidden_size
643
- self.num_heads = config.num_attention_heads
644
- self.use_alibi = config.alibi
645
 
646
  # Embedding + LN Embedding
647
  self.word_embeddings = nn.Embedding(config.vocab_size, self.embed_dim)
648
 
649
  # Transformer blocks
650
- self.h = nn.ModuleList([FalconDecoderLayer(config) for _ in range(config.num_hidden_layers)])
651
 
652
  # Final Layer Norm
653
  self.ln_f = LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
@@ -660,31 +513,22 @@ class FalconModel(FalconPreTrainedModel):
660
  def get_input_embeddings(self):
661
  return self.word_embeddings
662
 
663
- @staticmethod
664
  def _prepare_attn_mask(
665
- attention_mask: torch.Tensor, input_shape: Tuple[int, int], past_key_values_length: int
666
  ) -> torch.BoolTensor:
667
- # Create a causal mask
668
- # The attention mask we receive as input should cover the whole extended sequence, including any past
669
- # cache, so its shape should be [batch_size, seq_length + past_key_values_length]
670
- # The output shape will be [batch_size, 1, seq_length, seq_length + past_key_values_length]
671
- if input_shape[1] + past_key_values_length != attention_mask.shape[1]:
672
- raise ValueError(
673
- "Attention mask shape should be (batch_size, seq_length + past_key_values_length)"
674
- f" but is {attention_mask.shape} with input_ids shape {input_shape} and past length"
675
- f" {past_key_values_length}."
676
- )
677
  combined_attention_mask = None
678
  device = attention_mask.device
679
- _, seq_length = input_shape
680
 
681
- if seq_length > 1:
682
  combined_attention_mask = _make_causal_mask(
683
  input_shape, device=device, past_key_values_length=past_key_values_length
684
  )
685
 
686
- # [batch_size, seq_length + past_key_values_length] -> [batch_size, 1, seq_length, seq_length + past_key_values_length]
687
- expanded_attn_mask = _expand_mask(attention_mask, past_key_values_length=past_key_values_length)
688
  combined_attention_mask = (
689
  expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask | combined_attention_mask
690
  )
@@ -694,12 +538,6 @@ class FalconModel(FalconPreTrainedModel):
694
  def set_input_embeddings(self, new_embeddings: torch.Tensor):
695
  self.word_embeddings = new_embeddings
696
 
697
- @add_start_docstrings_to_model_forward(FALCON_INPUTS_DOCSTRING)
698
- @add_code_sample_docstrings(
699
- checkpoint=_CHECKPOINT_FOR_DOC,
700
- output_type=BaseModelOutputWithPastAndCrossAttentions,
701
- config_class=_CONFIG_FOR_DOC,
702
- )
703
  def forward(
704
  self,
705
  input_ids: Optional[torch.LongTensor] = None,
@@ -711,7 +549,18 @@ class FalconModel(FalconPreTrainedModel):
711
  output_attentions: Optional[bool] = None,
712
  output_hidden_states: Optional[bool] = None,
713
  return_dict: Optional[bool] = None,
 
714
  ) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]:
 
 
 
 
 
 
 
 
 
 
715
  output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
716
  output_hidden_states = (
717
  output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -730,14 +579,12 @@ class FalconModel(FalconPreTrainedModel):
730
 
731
  if past_key_values is None:
732
  past_key_values = tuple([None] * len(self.h))
733
- else:
734
- past_key_values = self._convert_to_rw_cache(past_key_values)
735
 
736
  # Prepare head mask if needed
737
  # 1.0 in head_mask indicate we keep the head
738
  # attention_probs has shape batch_size x num_heads x N x N
739
  # head_mask has shape n_layer x batch x num_heads x N x N
740
- head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
741
 
742
  if inputs_embeds is None:
743
  inputs_embeds = self.word_embeddings(input_ids)
@@ -749,15 +596,17 @@ class FalconModel(FalconPreTrainedModel):
749
  all_hidden_states = () if output_hidden_states else None
750
 
751
  # Compute alibi tensor: check build_alibi_tensor documentation
 
752
  past_key_values_length = 0
753
  if past_key_values[0] is not None:
754
- past_key_values_length = past_key_values[0][0].shape[1] # 1 because RW-cache, not standard format
 
755
  if attention_mask is None:
756
- attention_mask = torch.ones((batch_size, seq_length + past_key_values_length), device=hidden_states.device)
757
  else:
758
  attention_mask = attention_mask.to(hidden_states.device)
759
 
760
- if self.use_alibi:
761
  alibi = build_alibi_tensor(attention_mask, self.num_heads, dtype=hidden_states.dtype)
762
  else:
763
  alibi = None
@@ -769,10 +618,12 @@ class FalconModel(FalconPreTrainedModel):
769
  )
770
 
771
  for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
 
772
  if output_hidden_states:
773
  all_hidden_states = all_hidden_states + (hidden_states,)
774
 
775
  if self.gradient_checkpointing and self.training:
 
776
  if use_cache:
777
  logger.warning(
778
  "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
@@ -817,9 +668,6 @@ class FalconModel(FalconPreTrainedModel):
817
  if output_hidden_states:
818
  all_hidden_states = all_hidden_states + (hidden_states,)
819
 
820
- if presents is not None:
821
- presents = self._convert_cache_to_standard_format(presents, batch_size)
822
-
823
  if not return_dict:
824
  return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)
825
 
@@ -831,16 +679,12 @@ class FalconModel(FalconPreTrainedModel):
831
  )
832
 
833
 
834
- @add_start_docstrings(
835
- "The Falcon Model transformer with a language modeling head on top (linear layer with weights tied to the input embeddings).",
836
- FALCON_START_DOCSTRING,
837
- )
838
- class FalconForCausalLM(FalconPreTrainedModel):
839
- _tied_weights_keys = ["lm_head.weight"]
840
 
841
- def __init__(self, config: FalconConfig):
842
  super().__init__(config)
843
- self.transformer = FalconModel(config)
844
  self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
845
 
846
  # Initialize weights and apply final processing
@@ -855,26 +699,25 @@ class FalconForCausalLM(FalconPreTrainedModel):
855
  def prepare_inputs_for_generation(
856
  self,
857
  input_ids: torch.LongTensor,
858
- past_key_values: Optional[torch.Tensor] = None,
859
  attention_mask: Optional[torch.Tensor] = None,
860
  **kwargs,
861
  ) -> dict:
862
- if past_key_values is not None:
863
- input_ids = input_ids[:, -1:]
 
 
 
 
 
864
 
865
  return {
866
  "input_ids": input_ids,
867
- "past_key_values": past_key_values,
868
  "use_cache": kwargs.get("use_cache"),
869
  "attention_mask": attention_mask,
870
  }
871
 
872
- @add_start_docstrings_to_model_forward(FALCON_INPUTS_DOCSTRING)
873
- @add_code_sample_docstrings(
874
- checkpoint=_CHECKPOINT_FOR_DOC,
875
- output_type=CausalLMOutputWithCrossAttentions,
876
- config_class=_CONFIG_FOR_DOC,
877
- )
878
  def forward(
879
  self,
880
  input_ids: Optional[torch.LongTensor] = None,
@@ -887,6 +730,7 @@ class FalconForCausalLM(FalconPreTrainedModel):
887
  output_attentions: Optional[bool] = None,
888
  output_hidden_states: Optional[bool] = None,
889
  return_dict: Optional[bool] = None,
 
890
  ) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]:
891
  r"""
892
  labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -894,6 +738,15 @@ class FalconForCausalLM(FalconPreTrainedModel):
894
  `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
895
  are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
896
  """
 
 
 
 
 
 
 
 
 
897
 
898
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
899
 
@@ -946,6 +799,7 @@ class FalconForCausalLM(FalconPreTrainedModel):
946
 
947
  Output shares the same memory storage as `past`.
948
  """
 
949
 
950
  # Get a copy of `beam_idx` on all the devices where we need those indices.
951
  device_to_beam_idx = {
@@ -956,42 +810,23 @@ class FalconForCausalLM(FalconPreTrainedModel):
956
  layer_past[0].index_select(0, device_to_beam_idx[layer_past[0].device]),
957
  layer_past[1].index_select(0, device_to_beam_idx[layer_past[0].device]),
958
  )
959
- for layer_past in past
960
  )
961
- return reordered_past
962
 
963
 
964
- @add_start_docstrings(
965
- """
966
- The Falcon Model transformer with a sequence classification head on top (linear layer).
967
-
968
- [`FalconForSequenceClassification`] uses the last token in order to do the classification, as other causal models
969
- (e.g. GPT-1) do.
970
-
971
- Since it does classification on the last token, it requires to know the position of the last token. If a
972
- `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
973
- no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
974
- padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
975
- each row of the batch).
976
- """,
977
- FALCON_START_DOCSTRING,
978
- )
979
- class FalconForSequenceClassification(FalconPreTrainedModel):
980
- def __init__(self, config: FalconConfig):
981
  super().__init__(config)
982
  self.num_labels = config.num_labels
983
- self.transformer = FalconModel(config)
984
  self.score = nn.Linear(config.hidden_size, config.num_labels, bias=False)
985
 
986
  # Initialize weights and apply final processing
987
  self.post_init()
988
 
989
- @add_start_docstrings_to_model_forward(FALCON_INPUTS_DOCSTRING)
990
- @add_code_sample_docstrings(
991
- checkpoint=_CHECKPOINT_FOR_DOC,
992
- output_type=SequenceClassifierOutputWithPast,
993
- config_class=_CONFIG_FOR_DOC,
994
- )
995
  def forward(
996
  self,
997
  input_ids: Optional[torch.LongTensor] = None,
@@ -1004,6 +839,7 @@ class FalconForSequenceClassification(FalconPreTrainedModel):
1004
  output_attentions: Optional[bool] = None,
1005
  output_hidden_states: Optional[bool] = None,
1006
  return_dict: Optional[bool] = None,
 
1007
  ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutputWithPast]:
1008
  r"""
1009
  labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@@ -1011,6 +847,15 @@ class FalconForSequenceClassification(FalconPreTrainedModel):
1011
  config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1012
  `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1013
  """
 
 
 
 
 
 
 
 
 
1014
 
1015
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1016
 
@@ -1085,22 +930,17 @@ class FalconForSequenceClassification(FalconPreTrainedModel):
1085
  )
1086
 
1087
 
1088
- @add_start_docstrings(
1089
- """
1090
- Falcon Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
1091
- Named-Entity-Recognition (NER) tasks.
1092
- """,
1093
- FALCON_START_DOCSTRING,
1094
- )
1095
- class FalconForTokenClassification(FalconPreTrainedModel):
1096
- def __init__(self, config: FalconConfig):
1097
  super().__init__(config)
1098
  self.num_labels = config.num_labels
1099
 
1100
- self.transformer = FalconModel(config)
1101
- if getattr(config, "classifier_dropout", None) is not None:
1102
  classifier_dropout = config.classifier_dropout
1103
- elif getattr(config, "hidden_dropout", None) is not None:
1104
  classifier_dropout = config.hidden_dropout
1105
  else:
1106
  classifier_dropout = 0.1
@@ -1110,12 +950,6 @@ class FalconForTokenClassification(FalconPreTrainedModel):
1110
  # Initialize weights and apply final processing
1111
  self.post_init()
1112
 
1113
- @add_start_docstrings_to_model_forward(FALCON_INPUTS_DOCSTRING)
1114
- @add_code_sample_docstrings(
1115
- checkpoint=_CHECKPOINT_FOR_DOC,
1116
- output_type=TokenClassifierOutput,
1117
- config_class=_CONFIG_FOR_DOC,
1118
- )
1119
  def forward(
1120
  self,
1121
  input_ids: Optional[torch.LongTensor] = None,
@@ -1128,6 +962,7 @@ class FalconForTokenClassification(FalconPreTrainedModel):
1128
  output_attentions: Optional[bool] = None,
1129
  output_hidden_states: Optional[bool] = None,
1130
  return_dict: Optional[bool] = None,
 
1131
  ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
1132
  r"""
1133
  labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@@ -1135,6 +970,15 @@ class FalconForTokenClassification(FalconPreTrainedModel):
1135
  config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1136
  `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1137
  """
 
 
 
 
 
 
 
 
 
1138
 
1139
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1140
 
@@ -1158,9 +1002,7 @@ class FalconForTokenClassification(FalconPreTrainedModel):
1158
  if labels is not None:
1159
  batch_size, seq_length = labels.shape
1160
  loss_fct = CrossEntropyLoss()
1161
- loss = loss_fct(
1162
- logits.view(batch_size * seq_length, self.num_labels), labels.view(batch_size * seq_length)
1163
- )
1164
 
1165
  if not return_dict:
1166
  output = (logits,) + transformer_outputs[2:]
@@ -1174,27 +1016,22 @@ class FalconForTokenClassification(FalconPreTrainedModel):
1174
  )
1175
 
1176
 
1177
- @add_start_docstrings(
1178
- """
1179
- The Falcon Model transformer with a span classification head on top for extractive question-answering tasks like
1180
- SQuAD (a linear layers on top of the hidden-states output to compute `span start logits` and `span end logits`).
1181
- """,
1182
- FALCON_START_DOCSTRING,
1183
- )
1184
- class FalconForQuestionAnswering(FalconPreTrainedModel):
1185
  def __init__(self, config):
1186
  super().__init__(config)
1187
- self.transformer = FalconModel(config)
1188
  self.qa_outputs = nn.Linear(config.hidden_size, 2)
1189
 
1190
  # Initialize weights and apply final processing
1191
  self.post_init()
1192
 
1193
- @add_start_docstrings_to_model_forward(FALCON_INPUTS_DOCSTRING)
1194
  def forward(
1195
  self,
1196
  input_ids: Optional[torch.LongTensor] = None,
1197
  attention_mask: Optional[torch.FloatTensor] = None,
 
1198
  head_mask: Optional[torch.FloatTensor] = None,
1199
  inputs_embeds: Optional[torch.FloatTensor] = None,
1200
  start_positions: Optional[torch.LongTensor] = None,
@@ -1218,6 +1055,7 @@ class FalconForQuestionAnswering(FalconPreTrainedModel):
1218
  outputs = self.transformer(
1219
  input_ids,
1220
  attention_mask=attention_mask,
 
1221
  head_mask=head_mask,
1222
  inputs_embeds=inputs_embeds,
1223
  output_attentions=output_attentions,
 
1
+ # port of models described in RW
2
+ # We use the bloom model as a starting point for these model.
3
+ # Please refer to the bloom models for usage instructions.
 
 
 
 
 
 
 
 
 
 
 
 
4
 
5
  import math
6
+ import warnings
7
  from typing import Optional, Tuple, Union
8
 
9
  import torch
 
20
  TokenClassifierOutput,
21
  )
22
  from transformers.modeling_utils import PreTrainedModel
23
+ from transformers.utils import logging
24
+ from .configuration_RW import RWConfig
 
25
 
26
  logger = logging.get_logger(__name__)
27
 
 
 
 
 
 
 
 
 
 
 
 
 
28
  # NOTE(Hesslow): Unfortunately we did not fuse matmul and bias during training, this means that there's one additional quantization to bfloat16 between the operations.
29
  # In order not to degrade the quality of our HF-port, we keep these characteristics in the final model.
30
+ class Linear(nn.Linear):
31
  def forward(self, input: torch.Tensor) -> torch.Tensor:
32
+ ret = input @ self.weight.T
33
  if self.bias is None:
34
+ return ret
35
+ else:
36
+ return ret + self.bias
37
+
38
 
39
+ from einops import rearrange
40
 
41
  # rotary pos emb helpers (torch.jit.script does not seem to support staticmethod...)
42
  def rotate_half(x):
43
  x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
44
+ return torch.cat((-x2, x1), dim=x1.ndim - 1) # dim=-1 triggers a bug in torch < 1.8.0
45
 
46
 
47
+ class RotaryEmbedding(torch.nn.Module):
48
  """Implementation of RotaryEmbedding from GPT-NeoX.
49
+ This implementation is design to operate on queries and keys that are compatible with
50
+ [batch_size, n_heads_per_partition, seq_len, head_dim] (e.g. MinGPTAttention format).
51
  """
52
 
53
+ def __init__(
54
+ self,
55
+ head_dim: int,
56
+ base=10000,
57
+ ):
58
  super().__init__()
59
  inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
60
  self.register_buffer("inv_freq", inv_freq, persistent=False)
61
  self.head_dim = head_dim
62
+ self.seq_len_cached = None
63
+ self.batch_size_cached = None
64
  self.cos_cached: torch.Tensor | None = None
65
  self.sin_cached: torch.Tensor | None = None
66
 
67
+ def cos_sin(
68
+ self,
69
+ seq_len: int,
70
+ device="cuda",
71
+ dtype=torch.bfloat16,
72
+ ) -> torch.Tensor:
73
+ if seq_len != self.seq_len_cached:
74
+ self.seq_len_cached = seq_len
75
+ t = torch.arange(seq_len, device=device).type_as(self.inv_freq)
76
  freqs = torch.einsum("i,j->ij", t, self.inv_freq)
77
  emb = torch.cat((freqs, freqs), dim=-1).to(device)
78
 
 
85
  self.cos_cached = self.cos_cached.type(dtype)
86
  self.sin_cached = self.sin_cached.type(dtype)
87
 
88
+ return self.cos_cached, self.sin_cached
 
 
 
89
 
90
+ def forward(self, q, k):
91
+ batch, seq_len, head_dim = q.shape
92
+ cos, sin = self.cos_sin(seq_len, q.device, q.dtype)
93
+ return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)
94
 
95
 
96
  def _make_causal_mask(
97
  input_ids_shape: torch.Size, device: torch.device, past_key_values_length: int
98
  ) -> torch.BoolTensor:
 
 
 
 
 
99
  batch_size, target_length = input_ids_shape
100
+ mask = torch.empty((target_length, target_length + past_key_values_length), dtype=torch.bool, device=device)
101
+ # ONNX doesn't support `torch.Tensor.triu` properly, thus we use this workaround
102
+ seq_ids = torch.arange(target_length, device=device)
103
+ mask[:, past_key_values_length:] = seq_ids[:, None] < seq_ids[None, :]
104
+
105
+ if past_key_values_length > 0:
106
+ mask[:, :past_key_values_length] = False
107
 
 
 
 
 
 
 
108
  expanded_mask = mask[None, None, :, :].expand(batch_size, 1, target_length, target_length + past_key_values_length)
109
  return expanded_mask
110
 
111
 
112
+ def _expand_mask(mask: torch.Tensor, tgt_length: int) -> torch.BoolTensor:
113
+ batch_size, src_length = mask.shape
114
+ tgt_length = tgt_length if tgt_length is not None else src_length
 
 
 
115
 
116
  expanded_mask = ~(mask[:, None, None, :].to(torch.bool))
117
+ return expanded_mask.expand(batch_size, 1, tgt_length, src_length)
118
 
119
 
120
  def build_alibi_tensor(attention_mask: torch.Tensor, num_heads: int, dtype: torch.dtype) -> torch.Tensor:
 
145
  return alibi.reshape(batch_size * num_heads, 1, seq_length).to(dtype)
146
 
147
 
 
148
  def dropout_add(x: torch.Tensor, residual: torch.Tensor, prob: float, training: bool) -> torch.Tensor:
 
 
 
 
 
 
 
 
 
 
 
 
 
149
  out = F.dropout(x, p=prob, training=training)
150
  out = residual + out
151
  return out
152
 
153
 
154
+ class Attention(nn.Module):
155
+ def __init__(self, config: RWConfig):
156
  super().__init__()
157
 
158
  self.hidden_size = config.hidden_size
159
+ self.num_heads = config.n_head
160
  self.head_dim = self.hidden_size // self.num_heads
161
  self.split_size = self.hidden_size
162
  self.hidden_dropout = config.hidden_dropout
 
167
  f" {self.num_heads})."
168
  )
169
 
170
+ self.maybe_rotary = RotaryEmbedding(config.head_dim) if config.rotary else lambda q, k: (q, k)
171
 
172
  # Layer-wise attention scaling
173
  self.inv_norm_factor = 1.0 / math.sqrt(self.head_dim)
174
  self.beta = self.inv_norm_factor
175
+
176
+ self.query_key_value = Linear(
177
+ self.hidden_size,
178
+ 3 * self.hidden_size if not config.multi_query else (self.hidden_size + 2 * self.head_dim),
179
+ bias=config.bias,
180
+ )
 
 
181
  self.multi_query = config.multi_query
182
+ self.dense = Linear(self.hidden_size, self.hidden_size, bias=config.bias)
183
  self.attention_dropout = nn.Dropout(config.attention_dropout)
184
+ self.num_kv = config.n_head if not self.multi_query else 1
185
 
186
  def _split_heads(self, fused_qkv: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
187
  """
188
+ Split the last dimension into (num_heads, head_dim) without making any copies, results share same memory
189
+ storage as `fused_qkv`
190
 
191
  Args:
192
  fused_qkv (`torch.tensor`, *required*): [batch_size, seq_length, num_heads * 3 * head_dim]
 
195
  query: [batch_size, seq_length, num_heads, head_dim] key: [batch_size, seq_length, num_heads, head_dim]
196
  value: [batch_size, seq_length, num_heads, head_dim]
197
  """
198
+ if not self.multi_query:
 
 
 
 
 
 
 
 
 
 
 
199
  batch_size, seq_length, three_times_hidden_size = fused_qkv.shape
200
  fused_qkv = fused_qkv.view(batch_size, seq_length, self.num_heads, 3, self.head_dim)
201
  return fused_qkv[..., 0, :], fused_qkv[..., 1, :], fused_qkv[..., 2, :]
 
204
  fused_qkv = fused_qkv.view(batch_size, seq_length, self.num_heads + 2, self.head_dim)
205
  return fused_qkv[..., :-2, :], fused_qkv[..., [-2], :], fused_qkv[..., [-1], :]
206
 
 
207
  def _merge_heads(self, x: torch.Tensor) -> torch.Tensor:
208
  """
209
  Merge heads together over the last dimenstion
210
 
211
  Args:
212
+ x: (`torch.tensor`, *required*): [batch_size * num_heads, seq_length, head_dim]
213
 
214
  Returns:
215
  torch.tensor: [batch_size, seq_length, num_heads * head_dim]
 
232
  def forward(
233
  self,
234
  hidden_states: torch.Tensor,
235
+ alibi: torch.Tensor,
236
  attention_mask: torch.Tensor,
237
  layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
238
  head_mask: Optional[torch.Tensor] = None,
 
240
  output_attentions: bool = False,
241
  ):
242
  fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size]
243
+
244
  # 3 x [batch_size, seq_length, num_heads, head_dim]
245
  (query_layer, key_layer, value_layer) = self._split_heads(fused_qkv)
246
 
247
+ batch_size, q_length, _, _ = query_layer.shape
248
 
249
+ query_layer = query_layer.transpose(1, 2).reshape(batch_size * self.num_heads, q_length, self.head_dim)
250
  key_layer = key_layer.transpose(1, 2).reshape(
251
+ batch_size * self.num_kv,
252
+ q_length,
253
  self.head_dim,
254
  )
255
+ value_layer = value_layer.transpose(1, 2).reshape(batch_size * self.num_kv, q_length, self.head_dim)
256
 
257
+ query_layer, key_layer = self.maybe_rotary(query_layer, key_layer)
 
258
 
259
  if layer_past is not None:
260
  past_key, past_value = layer_past
261
  # concatenate along seq_length dimension:
262
+ # - key: [batch_size * self.num_heads, head_dim, kv_length]
263
  # - value: [batch_size * self.num_heads, kv_length, head_dim]
264
  key_layer = torch.cat((past_key, key_layer), dim=1)
265
  value_layer = torch.cat((past_value, value_layer), dim=1)
266
 
267
  _, kv_length, _ = key_layer.shape
268
+
269
+ if use_cache is True:
270
  present = (key_layer, value_layer)
271
  else:
272
  present = None
273
 
 
 
 
 
 
 
274
  if alibi is None:
275
+ query_layer_ = query_layer.reshape(batch_size, self.num_heads, -1, self.head_dim)
276
+ key_layer_ = key_layer.reshape(batch_size, self.num_kv, -1, self.head_dim)
277
+ value_layer_ = value_layer.reshape(batch_size, self.num_kv, -1, self.head_dim)
 
 
278
 
279
+ attn_output = F.scaled_dot_product_attention(
280
+ query_layer_, key_layer_, value_layer_, None, 0.0, is_causal=True
281
+ )
 
 
 
 
 
 
282
 
283
+ x = attn_output.view(batch_size, self.num_heads, q_length, self.head_dim)
284
+ x = x.permute(0, 2, 1, 3)
285
+ attn_output = x.reshape(batch_size, q_length, self.num_heads * self.head_dim)
286
 
287
  output_tensor = self.dense(attn_output)
288
 
289
+ outputs = (output_tensor, present)
290
+ assert not output_attentions # not supported.
291
+ return outputs
 
 
292
  else:
293
+ attention_mask_float = (attention_mask * 1.0).masked_fill(attention_mask, -1e9).to(torch.bfloat16)
294
+ matmul_result = query_layer @ key_layer.transpose(-1, -2)
295
 
296
  # change view to [batch_size, num_heads, q_length, kv_length]
297
+ attention_scores = matmul_result.view(batch_size, self.num_heads, q_length, kv_length)
298
 
299
  # cast attention scores to fp32, compute scaled softmax and cast back to initial dtype - [batch_size, num_heads, q_length, kv_length]
300
  input_dtype = attention_scores.dtype
301
  # `float16` has a minimum value of -65504.0, whereas `bfloat16` and `float32` have a minimum value of `-3.4e+38`
302
  if input_dtype == torch.float16 or input_dtype == torch.bfloat16:
303
  attention_scores = attention_scores.to(torch.float32)
304
+ # attn_weights = torch.masked_fill(attention_scores, attention_mask, torch.finfo(attention_scores.dtype).min)
305
+ attention_probs = F.softmax(
306
+ (attention_scores + alibi.view(batch_size, self.num_heads, 1, -1)) * self.inv_norm_factor + attention_mask_float,
307
+ dim=-1,
308
+ dtype=hidden_states.dtype,
309
+ )
 
310
  # [batch_size, num_heads, q_length, kv_length]
311
  attention_probs = self.attention_dropout(attention_probs)
312
 
313
  if head_mask is not None:
314
  attention_probs = attention_probs * head_mask
315
 
316
+ # change view [batch_size x num_heads, q_length, kv_length]
317
+ attention_probs_reshaped = attention_probs.view(batch_size * self.num_heads, q_length, kv_length)
318
 
319
  # matmul: [batch_size * num_heads, q_length, head_dim]
320
+ context_layer = attention_probs_reshaped @ value_layer
321
 
322
  # change view [batch_size, num_heads, q_length, head_dim]
323
  context_layer = self._merge_heads(context_layer)
324
 
325
  output_tensor = self.dense(context_layer)
326
 
327
+ outputs = (output_tensor, present)
328
  if output_attentions:
329
+ outputs += (attention_probs,)
330
+
331
+ return outputs
332
 
333
 
334
+ class MLP(nn.Module):
335
+ def __init__(self, config: RWConfig):
336
  super().__init__()
337
  hidden_size = config.hidden_size
338
 
339
+ self.dense_h_to_4h = Linear(hidden_size, 4 * hidden_size, bias=config.bias)
340
  self.act = nn.GELU()
341
+ self.dense_4h_to_h = Linear(4 * hidden_size, hidden_size, bias=config.bias)
342
  self.hidden_dropout = config.hidden_dropout
343
 
344
  def forward(self, x: torch.Tensor) -> torch.Tensor:
 
347
  return x
348
 
349
 
350
+ class DecoderLayer(nn.Module):
351
+ def __init__(self, config: RWConfig):
352
  super().__init__()
353
  hidden_size = config.hidden_size
354
+
355
+ self.input_layernorm = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
356
+ self.num_heads = config.n_head
357
+ self.self_attention = Attention(config)
358
+
359
+ if not config.parallel_attn:
360
+ # unused if parallel attn
361
+ self.post_attention_layernorm = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
362
+
363
+ self.mlp = MLP(config)
364
+
365
+ self.apply_residual_connection_post_layernorm = config.apply_residual_connection_post_layernorm
366
  self.hidden_dropout = config.hidden_dropout
 
367
 
368
+ self.config = config
 
 
 
 
 
 
 
 
369
 
370
  def forward(
371
  self,
372
  hidden_states: torch.Tensor,
373
+ alibi: torch.Tensor,
374
  attention_mask: torch.Tensor,
375
  layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
376
  head_mask: Optional[torch.Tensor] = None,
377
  use_cache: bool = False,
378
  output_attentions: bool = False,
379
  ):
 
380
 
381
+ layernorm_output = self.input_layernorm(hidden_states)
382
+ residual = hidden_states
 
 
 
383
 
384
  # Self attention.
385
  attn_outputs = self.self_attention(
386
+ layernorm_output,
387
  layer_past=layer_past,
388
  attention_mask=attention_mask,
389
  alibi=alibi,
 
394
 
395
  attention_output = attn_outputs[0]
396
 
397
+ if not self.config.parallel_attn:
398
+ residual = dropout_add(attention_output, residual, self.config.attention_dropout, training=self.training)
399
+ layernorm_output = self.post_attention_layernorm(residual)
 
 
 
 
 
400
 
401
  outputs = attn_outputs[1:]
402
 
403
  # MLP.
404
+ mlp_output = self.mlp(layernorm_output)
405
 
406
+ if self.config.parallel_attn:
407
  mlp_output += attention_output
408
 
409
  output = dropout_add(mlp_output, residual, self.config.hidden_dropout, training=self.training)
 
416
  return outputs # hidden_states, present, attentions
417
 
418
 
419
+ class RWPreTrainedModel(PreTrainedModel):
420
+ _keys_to_ignore_on_load_missing = [r"h.*.self_attention.scale_mask_softmax.causal_mask", r"lm_head.weight"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
421
  """
422
  An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
423
  models.
424
  """
425
 
426
+ config_class = RWConfig
427
  base_model_prefix = "transformer"
428
  supports_gradient_checkpointing = True
429
+ _no_split_modules = ["DecoderLayer"]
430
 
431
  def __init__(self, *inputs, **kwargs):
432
  super().__init__(*inputs, **kwargs)
433
 
434
  def _init_weights(self, module: nn.Module):
435
  """Initialize the weights."""
436
+ if isinstance(module, nn.Linear) or isinstance(module, Linear):
437
  # Slightly different from the TF version which uses truncated_normal for initialization
438
  # cf https://github.com/pytorch/pytorch/pull/5617
439
  module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
 
447
  module.bias.data.zero_()
448
  module.weight.data.fill_(1.0)
449
 
 
450
  def _set_gradient_checkpointing(self, module: nn.Module, value: bool = False):
451
+ if isinstance(module, RWModel):
452
  module.gradient_checkpointing = value
453
 
454
  @staticmethod
455
+ def _convert_to_standard_cache(
456
  past_key_value: Tuple[Tuple[torch.Tensor, torch.Tensor]], batch_size: int
457
  ) -> Tuple[Tuple[torch.Tensor, torch.Tensor]]:
458
  """
459
  Standardizes the format of the cache so as to match most implementations, i.e. to tuple(tuple([batch_size,
460
  num_heads, ...]))
461
  """
462
+ batch_size_times_num_heads, head_dim, seq_length = past_key_value[0][0].shape
 
 
 
463
  num_heads = batch_size_times_num_heads // batch_size
464
+ # key: [batch_size * num_heads, head_dim, seq_length] -> [batch_size, num_heads, head_dim, seq_length]
465
+ # value: [batch_size * num_heads, seq_length, head_dim] -> [batch_size, num_heads, seq_length, head_dim]
466
  return tuple(
467
  (
468
+ layer_past[0].view(batch_size, num_heads, head_dim, seq_length),
469
+ layer_past[1].view(batch_size, num_heads, seq_length, head_dim),
470
  )
471
  for layer_past in past_key_value
472
  )
 
475
  def _convert_to_rw_cache(
476
  past_key_value: Tuple[Tuple[torch.Tensor, torch.Tensor]]
477
  ) -> Tuple[Tuple[torch.Tensor, torch.Tensor]]:
478
+ batch_size, num_heads, head_dim, seq_length = past_key_value[0][0].shape
479
  batch_size_times_num_heads = batch_size * num_heads
480
+ # key: [batch_size, num_heads, head_dim, seq_length] -> [batch_size * num_heads, head_dim, seq_length]
481
+ # value: [batch_size, num_heads, seq_length, head_dim] -> [batch_size * num_heads, seq_length, head_dim]
482
  return tuple(
483
  (
484
+ layer_past[0].view(batch_size_times_num_heads, head_dim, seq_length),
485
+ layer_past[1].view(batch_size_times_num_heads, seq_length, head_dim),
486
  )
487
  for layer_past in past_key_value
488
  )
489
 
490
 
491
+ class RWModel(RWPreTrainedModel):
492
+ def __init__(self, config: RWConfig):
 
 
 
 
493
  super().__init__(config)
494
 
495
  self.embed_dim = config.hidden_size
496
+ self.num_heads = config.n_head
497
+ self.alibi = config.alibi
498
 
499
  # Embedding + LN Embedding
500
  self.word_embeddings = nn.Embedding(config.vocab_size, self.embed_dim)
501
 
502
  # Transformer blocks
503
+ self.h = nn.ModuleList([DecoderLayer(config) for _ in range(config.num_hidden_layers)])
504
 
505
  # Final Layer Norm
506
  self.ln_f = LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
 
513
  def get_input_embeddings(self):
514
  return self.word_embeddings
515
 
 
516
  def _prepare_attn_mask(
517
+ self, attention_mask: torch.Tensor, input_shape: Tuple[int, int], past_key_values_length: int
518
  ) -> torch.BoolTensor:
519
+ # create causal mask
520
+ # [batch_size, seq_length] -> [batch_size, 1, tgt_length, src_length]
 
 
 
 
 
 
 
 
521
  combined_attention_mask = None
522
  device = attention_mask.device
523
+ _, src_length = input_shape
524
 
525
+ if src_length > 1:
526
  combined_attention_mask = _make_causal_mask(
527
  input_shape, device=device, past_key_values_length=past_key_values_length
528
  )
529
 
530
+ # [batch_size, seq_length] -> [batch_size, 1, tgt_length, src_length]
531
+ expanded_attn_mask = _expand_mask(attention_mask, tgt_length=src_length)
532
  combined_attention_mask = (
533
  expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask | combined_attention_mask
534
  )
 
538
  def set_input_embeddings(self, new_embeddings: torch.Tensor):
539
  self.word_embeddings = new_embeddings
540
 
 
 
 
 
 
 
541
  def forward(
542
  self,
543
  input_ids: Optional[torch.LongTensor] = None,
 
549
  output_attentions: Optional[bool] = None,
550
  output_hidden_states: Optional[bool] = None,
551
  return_dict: Optional[bool] = None,
552
+ **deprecated_arguments,
553
  ) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]:
554
+ if deprecated_arguments.pop("position_ids", False) is not False:
555
+ # `position_ids` could have been `torch.Tensor` or `None` so defaulting pop to `False` allows to detect if users were passing explicitly `None`
556
+ warnings.warn(
557
+ "`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. You can safely ignore"
558
+ " passing `position_ids`.",
559
+ FutureWarning,
560
+ )
561
+ if len(deprecated_arguments) > 0:
562
+ raise ValueError(f"Got unexpected arguments: {deprecated_arguments}")
563
+
564
  output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
565
  output_hidden_states = (
566
  output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
 
579
 
580
  if past_key_values is None:
581
  past_key_values = tuple([None] * len(self.h))
 
 
582
 
583
  # Prepare head mask if needed
584
  # 1.0 in head_mask indicate we keep the head
585
  # attention_probs has shape batch_size x num_heads x N x N
586
  # head_mask has shape n_layer x batch x num_heads x N x N
587
+ head_mask = self.get_head_mask(head_mask, self.config.n_layer)
588
 
589
  if inputs_embeds is None:
590
  inputs_embeds = self.word_embeddings(input_ids)
 
596
  all_hidden_states = () if output_hidden_states else None
597
 
598
  # Compute alibi tensor: check build_alibi_tensor documentation
599
+ seq_length_with_past = seq_length
600
  past_key_values_length = 0
601
  if past_key_values[0] is not None:
602
+ past_key_values_length = past_key_values[0][0].shape[2]
603
+ seq_length_with_past = seq_length_with_past + past_key_values_length
604
  if attention_mask is None:
605
+ attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device)
606
  else:
607
  attention_mask = attention_mask.to(hidden_states.device)
608
 
609
+ if self.alibi:
610
  alibi = build_alibi_tensor(attention_mask, self.num_heads, dtype=hidden_states.dtype)
611
  else:
612
  alibi = None
 
618
  )
619
 
620
  for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
621
+
622
  if output_hidden_states:
623
  all_hidden_states = all_hidden_states + (hidden_states,)
624
 
625
  if self.gradient_checkpointing and self.training:
626
+
627
  if use_cache:
628
  logger.warning(
629
  "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
 
668
  if output_hidden_states:
669
  all_hidden_states = all_hidden_states + (hidden_states,)
670
 
 
 
 
671
  if not return_dict:
672
  return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)
673
 
 
679
  )
680
 
681
 
682
+ class RWForCausalLM(RWPreTrainedModel):
683
+ _keys_to_ignore_on_load_missing = [r"h.*.self_attention.scale_mask_softmax.causal_mask", r"lm_head.weight"]
 
 
 
 
684
 
685
+ def __init__(self, config: RWConfig):
686
  super().__init__(config)
687
+ self.transformer = RWModel(config)
688
  self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
689
 
690
  # Initialize weights and apply final processing
 
699
  def prepare_inputs_for_generation(
700
  self,
701
  input_ids: torch.LongTensor,
702
+ past: Optional[torch.Tensor] = None,
703
  attention_mask: Optional[torch.Tensor] = None,
704
  **kwargs,
705
  ) -> dict:
706
+ # only last token for input_ids if past is not None
707
+ if past:
708
+ input_ids = input_ids[:, -1].unsqueeze(-1)
709
+
710
+ # the cache may be in the stardard format (e.g. in contrastive search), convert to our's format if needed
711
+ if past[0][0].shape[0] == input_ids.shape[0]:
712
+ past = self._convert_to_rw_cache(past)
713
 
714
  return {
715
  "input_ids": input_ids,
716
+ "past_key_values": past,
717
  "use_cache": kwargs.get("use_cache"),
718
  "attention_mask": attention_mask,
719
  }
720
 
 
 
 
 
 
 
721
  def forward(
722
  self,
723
  input_ids: Optional[torch.LongTensor] = None,
 
730
  output_attentions: Optional[bool] = None,
731
  output_hidden_states: Optional[bool] = None,
732
  return_dict: Optional[bool] = None,
733
+ **deprecated_arguments,
734
  ) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]:
735
  r"""
736
  labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
 
738
  `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
739
  are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
740
  """
741
+ if deprecated_arguments.pop("position_ids", False) is not False:
742
+ # `position_ids` could have been `torch.Tensor` or `None` so defaulting pop to `False` allows to detect if users were passing explicitly `None`
743
+ warnings.warn(
744
+ "`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. You can safely ignore"
745
+ " passing `position_ids`.",
746
+ FutureWarning,
747
+ )
748
+ if len(deprecated_arguments) > 0:
749
+ raise ValueError(f"Got unexpected arguments: {deprecated_arguments}")
750
 
751
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
752
 
 
799
 
800
  Output shares the same memory storage as `past`.
801
  """
802
+ standardized_past = self._convert_to_standard_cache(past, batch_size=len(beam_idx))
803
 
804
  # Get a copy of `beam_idx` on all the devices where we need those indices.
805
  device_to_beam_idx = {
 
810
  layer_past[0].index_select(0, device_to_beam_idx[layer_past[0].device]),
811
  layer_past[1].index_select(0, device_to_beam_idx[layer_past[0].device]),
812
  )
813
+ for layer_past in standardized_past
814
  )
815
+ return self._convert_to_rw_cache(reordered_past)
816
 
817
 
818
+ class RWForSequenceClassification(RWPreTrainedModel):
819
+ _keys_to_ignore_on_load_missing = [r"h.*.self_attention.scale_mask_softmax.causal_mask", r"lm_head.weight"]
820
+
821
+ def __init__(self, config: RWConfig):
 
 
 
 
 
 
 
 
 
 
 
 
 
822
  super().__init__(config)
823
  self.num_labels = config.num_labels
824
+ self.transformer = RWModel(config)
825
  self.score = nn.Linear(config.hidden_size, config.num_labels, bias=False)
826
 
827
  # Initialize weights and apply final processing
828
  self.post_init()
829
 
 
 
 
 
 
 
830
  def forward(
831
  self,
832
  input_ids: Optional[torch.LongTensor] = None,
 
839
  output_attentions: Optional[bool] = None,
840
  output_hidden_states: Optional[bool] = None,
841
  return_dict: Optional[bool] = None,
842
+ **deprecated_arguments,
843
  ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutputWithPast]:
844
  r"""
845
  labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
 
847
  config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
848
  `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
849
  """
850
+ if deprecated_arguments.pop("position_ids", False) is not False:
851
+ # `position_ids` could have been `torch.Tensor` or `None` so defaulting pop to `False` allows to detect if users were passing explicitly `None`
852
+ warnings.warn(
853
+ "`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. You can safely ignore"
854
+ " passing `position_ids`.",
855
+ FutureWarning,
856
+ )
857
+ if len(deprecated_arguments) > 0:
858
+ raise ValueError(f"Got unexpected arguments: {deprecated_arguments}")
859
 
860
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
861
 
 
930
  )
931
 
932
 
933
+ class RWForTokenClassification(RWPreTrainedModel):
934
+ _keys_to_ignore_on_load_missing = [r"h.*.self_attention.scale_mask_softmax.causal_mask", r"lm_head.weight"]
935
+
936
+ def __init__(self, config: RWConfig):
 
 
 
 
 
937
  super().__init__(config)
938
  self.num_labels = config.num_labels
939
 
940
+ self.transformer = RWModel(config)
941
+ if hasattr(config, "classifier_dropout") and config.classifier_dropout is not None:
942
  classifier_dropout = config.classifier_dropout
943
+ elif hasattr(config, "hidden_dropout") and config.hidden_dropout is not None:
944
  classifier_dropout = config.hidden_dropout
945
  else:
946
  classifier_dropout = 0.1
 
950
  # Initialize weights and apply final processing
951
  self.post_init()
952
 
 
 
 
 
 
 
953
  def forward(
954
  self,
955
  input_ids: Optional[torch.LongTensor] = None,
 
962
  output_attentions: Optional[bool] = None,
963
  output_hidden_states: Optional[bool] = None,
964
  return_dict: Optional[bool] = None,
965
+ **deprecated_arguments,
966
  ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
967
  r"""
968
  labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
 
970
  config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
971
  `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
972
  """
973
+ if deprecated_arguments.pop("position_ids", False) is not False:
974
+ # `position_ids` could have been `torch.Tensor` or `None` so defaulting pop to `False` allows to detect if users were passing explicitly `None`
975
+ warnings.warn(
976
+ "`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. You can safely ignore"
977
+ " passing `position_ids`.",
978
+ FutureWarning,
979
+ )
980
+ if len(deprecated_arguments) > 0:
981
+ raise ValueError(f"Got unexpected arguments: {deprecated_arguments}")
982
 
983
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
984
 
 
1002
  if labels is not None:
1003
  batch_size, seq_length = labels.shape
1004
  loss_fct = CrossEntropyLoss()
1005
+ loss = loss_fct(logits.view(batch_size * seq_length, self.num_labels), labels.view(batch_size * seq_length))
 
 
1006
 
1007
  if not return_dict:
1008
  output = (logits,) + transformer_outputs[2:]
 
1016
  )
1017
 
1018
 
1019
+ class RWForQuestionAnswering(RWPreTrainedModel):
1020
+ _keys_to_ignore_on_load_missing = [r"h.*.self_attention.scale_mask_softmax.causal_mask", r"lm_head.weight"]
1021
+
 
 
 
 
 
1022
  def __init__(self, config):
1023
  super().__init__(config)
1024
+ self.transformer = RWModel(config)
1025
  self.qa_outputs = nn.Linear(config.hidden_size, 2)
1026
 
1027
  # Initialize weights and apply final processing
1028
  self.post_init()
1029
 
 
1030
  def forward(
1031
  self,
1032
  input_ids: Optional[torch.LongTensor] = None,
1033
  attention_mask: Optional[torch.FloatTensor] = None,
1034
+ position_ids: Optional[torch.LongTensor] = None,
1035
  head_mask: Optional[torch.FloatTensor] = None,
1036
  inputs_embeds: Optional[torch.FloatTensor] = None,
1037
  start_positions: Optional[torch.LongTensor] = None,
 
1055
  outputs = self.transformer(
1056
  input_ids,
1057
  attention_mask=attention_mask,
1058
+ position_ids=position_ids,
1059
  head_mask=head_mask,
1060
  inputs_embeds=inputs_embeds,
1061
  output_attentions=output_attentions,