OpenNLPLab commited on
Commit
460b22e
·
1 Parent(s): cf95141

Fix issues regarding to transformer version

Browse files
generation_config.json CHANGED
@@ -1,6 +1,9 @@
1
  {
2
- "_from_model_config": true,
3
  "bos_token_id": 1,
4
  "eos_token_id": 2,
5
- "transformers_version": "4.31.0"
 
 
 
6
  }
 
1
  {
2
+ "pad_token_id": 0,
3
  "bos_token_id": 1,
4
  "eos_token_id": 2,
5
+ "max_new_tokens": 2048,
6
+ "temperature": 1.0,
7
+ "repetition_penalty": 1.03,
8
+ "do_sample": true
9
  }
modeling_transnormer.py CHANGED
@@ -11,8 +11,7 @@
11
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
  # See the License for the specific language governing permissions and
13
  # limitations under the License.
14
-
15
- # coding=utf-8
16
  """ PyTorch Transnormer model."""
17
  import math
18
  import os
@@ -29,7 +28,6 @@ from transformers.activations import ACT2FN
29
  from transformers.modeling_outputs import (
30
  BaseModelOutputWithPast,
31
  CausalLMOutputWithPast,
32
- SequenceClassifierOutputWithPast,
33
  )
34
  from transformers.modeling_utils import PreTrainedModel
35
  from transformers.utils import (
@@ -85,7 +83,6 @@ if not has_lightning_attention:
85
  ########## start Transnormer
86
  ##### Linearized Relative Positional Encoding: https://openreview.net/forum?id=xoLyps2qWc&referrer=%5BAuthor%20Console%5D(%2Fgroup%3Fid%3DTMLR%2FAuthors%23your-submissions)
87
  class Lrpe(nn.Module):
88
-
89
  def __init__(
90
  self,
91
  num_heads=8,
@@ -95,8 +92,9 @@ class Lrpe(nn.Module):
95
  d = num_heads * embed_dim
96
 
97
  self.index = torch.empty(0)
98
- self.theta = nn.Parameter(10000**(-2 / d * torch.arange(d)).reshape(
99
- num_heads, 1, -1))
 
100
 
101
  def extra_repr(self):
102
  return print_module(self)
@@ -115,7 +113,6 @@ class Lrpe(nn.Module):
115
 
116
 
117
  class GLU(nn.Module):
118
-
119
  def __init__(self, d1, d2, bias=False):
120
  super().__init__()
121
  if debug:
@@ -138,7 +135,6 @@ class GLU(nn.Module):
138
 
139
 
140
  class NormLinearAttention(nn.Module):
141
-
142
  def __init__(
143
  self,
144
  embed_dim,
@@ -194,7 +190,7 @@ class NormLinearAttention(nn.Module):
194
  output_attentions,
195
  past_key_value,
196
  use_cache,
197
- slope_rate=slope_rate,
198
  )
199
  # x: b n d
200
  n = x.shape[-2]
@@ -202,8 +198,8 @@ class NormLinearAttention(nn.Module):
202
  q, k, v, u = self.qkvu_proj(x).chunk(4, dim=-1)
203
  # reshape
204
  q, k, v = map(
205
- lambda x: rearrange(x, "b n (h d) -> b h n d", h=self.num_heads),
206
- [q, k, v])
207
  # act
208
  q = self.act(q)
209
  k = self.act(k)
@@ -211,7 +207,7 @@ class NormLinearAttention(nn.Module):
211
  q_offset = 0
212
  # lrpe relys on position, get cache first
213
  if past_key_value is not None:
214
- # reuse k, v, self_attention
215
  k = torch.cat([past_key_value[0], k], dim=-2)
216
  v = torch.cat([past_key_value[1], v], dim=-2)
217
  q_offset = past_key_value[0].shape[-2]
@@ -228,17 +224,17 @@ class NormLinearAttention(nn.Module):
228
 
229
  if attn_padding_mask is not None:
230
  v = v.masked_fill(
231
- (1 - attn_padding_mask).unsqueeze(1).unsqueeze(-1).to(
232
- torch.bool), 0)
233
 
234
  if not has_lightning_attention:
235
  if slope_rate != None:
236
  attn_mask = torch.exp(slope_rate * attn_mask)
237
-
238
  output = linear_attention(q, k, v, attn_mask)
239
  else:
240
- output = lightning_attention(q, k, v, True,
241
- slope_rate.squeeze(-1).squeeze(-1))
 
242
 
243
  # reshape
244
  output = rearrange(output, "b h n d -> b n (h d)")
@@ -257,14 +253,14 @@ class NormLinearAttention(nn.Module):
257
  return output, attn_weights, past_key_value
258
 
259
  def inference(
260
- self,
261
- x,
262
- attn_mask: Optional[torch.Tensor] = None, # (b, h, n, m)
263
- attn_padding_mask: Optional[torch.Tensor] = None, # (b, m)
264
- output_attentions: bool = False,
265
- past_key_value: Optional[Tuple[torch.Tensor]] = None,
266
- use_cache: bool = False,
267
- slope_rate: Optional[torch.Tensor] = None, # (h, 1, 1)
268
  ):
269
  # x: b n d
270
  n = x.shape[-2]
@@ -272,8 +268,8 @@ class NormLinearAttention(nn.Module):
272
  q, k, v, u = self.qkvu_proj(x).chunk(4, dim=-1)
273
  # reshape
274
  q, k, v = map(
275
- lambda x: rearrange(x, "b n (h d) -> b h n d", h=self.num_heads),
276
- [q, k, v])
277
  # act
278
  q = self.act(q)
279
  k = self.act(k)
@@ -281,7 +277,7 @@ class NormLinearAttention(nn.Module):
281
  # rpe
282
  if self.linear_use_lrpe:
283
  q = self.lrpe(q, offset=self.offset)
284
- k = self.lrpe(k, offset=self.offset)
285
 
286
  if past_key_value == None:
287
  self.offset = q.shape[-2]
@@ -299,8 +295,7 @@ class NormLinearAttention(nn.Module):
299
 
300
  if attn_padding_mask is not None:
301
  attn_mask = attn_mask.masked_fill(
302
- (1 - attn_padding_mask).unsqueeze(1).unsqueeze(2).to(
303
- torch.bool),
304
  0,
305
  )
306
  energy = torch.einsum("... n d, ... m d -> ... n m", q, k)
@@ -311,18 +306,17 @@ class NormLinearAttention(nn.Module):
311
  output = torch.einsum("... n m, ... m d -> ... n d", energy, v)
312
 
313
  eval_and_not_generate = eval(
314
- os.environ.get("eval_and_not_generate", default="False"))
 
315
  if eval_and_not_generate:
316
  kv = None
317
  else:
318
  # b, h, n, e, d
319
- kv_outproduct = torch.einsum("... n e, ... n d -> ... n e d",
320
- k, v)
321
  # 1, 1, n, 1, 1
322
- index = torch.arange(n - 1, -1, -1).reshape(1, 1, -1, 1,
323
- 1).to(x)
324
  # (h, 1, 1) -> (1, h, 1, 1, 1); (1, h, 1, 1, 1), (1, 1, n, 1, 1) -> (1, h, n, 1, 1)
325
- decay = ratio.unsqueeze(0).unsqueeze(-1)**index
326
 
327
  kv_outproduct_with_decay = kv_outproduct * decay
328
  kv = torch.sum(kv_outproduct_with_decay, dim=-3)
@@ -333,11 +327,12 @@ class NormLinearAttention(nn.Module):
333
  for i in range(n):
334
  kv = ratio * kv + torch.einsum(
335
  "... n d, ... n e -> ... d e",
336
- k[:, :, i:i + 1],
337
- v[:, :, i:i + 1],
 
 
 
338
  )
339
- qkv = torch.einsum("... n e, ... e d -> ... n d",
340
- q[:, :, i:i + 1], kv)
341
  output.append(qkv)
342
  output = torch.concat(output, dim=-2)
343
 
@@ -356,7 +351,6 @@ class NormLinearAttention(nn.Module):
356
 
357
 
358
  class TransnormerDecoderLayer(nn.Module):
359
-
360
  def __init__(self, config: TransnormerConfig):
361
  super().__init__()
362
  self.embed_dim = config.decoder_embed_dim
@@ -395,14 +389,14 @@ class TransnormerDecoderLayer(nn.Module):
395
  return residual + x
396
 
397
  def forward(
398
- self,
399
- x,
400
- attn_mask: Optional[torch.Tensor] = None,
401
- attn_padding_mask: Optional[torch.Tensor] = None,
402
- past_key_value: Optional[Tuple[torch.Tensor]] = None,
403
- output_attentions: Optional[bool] = False,
404
- use_cache: Optional[bool] = False,
405
- slope_rate: Optional[torch.Tensor] = None, # (h, 1, 1)
406
  ):
407
  residual = x
408
  x = self.token_norm(x)
@@ -422,13 +416,13 @@ class TransnormerDecoderLayer(nn.Module):
422
  x = self.channel_mixer(x)
423
  x = self.residual_connection(x, residual)
424
 
425
- outputs = (x, )
426
 
427
  if output_attentions:
428
- outputs += (self_attn_weights, )
429
 
430
  if use_cache:
431
- outputs += (present_key_value, )
432
 
433
  return outputs
434
 
@@ -450,7 +444,9 @@ TRANSNORMER_START_DOCSTRING = r"""
450
  """
451
 
452
 
453
- @add_start_docstrings(TRANSNORMER_START_DOCSTRING, )
 
 
454
  class TransnormerPreTrainedModel(PreTrainedModel):
455
  config_class = TransnormerConfig
456
  base_model_prefix = "model"
@@ -535,7 +531,9 @@ TRANSNORMER_INPUTS_DOCSTRING = r"""
535
  """
536
 
537
 
538
- @add_start_docstrings(TRANSNORMER_START_DOCSTRING, )
 
 
539
  class TransnormerModel(TransnormerPreTrainedModel):
540
  """
541
  Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`TransnormerDecoderLayer`]
@@ -559,31 +557,29 @@ class TransnormerModel(TransnormerPreTrainedModel):
559
  self.slopes = self._build_slope_tensor(config.decoder_attention_heads)
560
 
561
  # params
562
- self.embed_tokens = nn.Embedding(config.vocab_size,
563
- config.decoder_embed_dim,
564
- self.padding_idx)
565
  self.layers = nn.ModuleList([])
566
  for i in range(config.decoder_layers):
567
  if len(self.linear_use_lrpe_list) > 0:
568
  config.linear_use_lrpe = self.linear_use_lrpe_list[i]
569
  self.layers.append(TransnormerDecoderLayer(config))
570
 
571
- self.final_norm = get_norm_fn(config.norm_type)(
572
- config.decoder_embed_dim)
573
  self.embed_dim = config.decoder_embed_dim
574
- self.embed_scale = (1.0 if config.no_scale_embedding else math.sqrt(
575
- self.embed_dim))
 
576
 
577
  # Initialize weights and apply final processing
578
  self.post_init()
579
 
580
  @staticmethod
581
  def _build_slope_tensor(n_attention_heads: int):
582
-
583
  def get_slopes(n):
584
-
585
  def get_slopes_power_of_2(n):
586
- start = 2**(-(2**-(math.log2(n) - 3)))
587
  ratio = start
588
  return [start * ratio**i for i in range(n)]
589
 
@@ -592,15 +588,18 @@ class TransnormerModel(TransnormerPreTrainedModel):
592
  n
593
  ) # In the paper, we only train models that have 2^a heads for some a. This function has
594
  else: # some good properties that only occur when the input is a power of 2. To maintain that even
595
- closest_power_of_2 = 2**math.floor(
596
  math.log2(n)
597
  ) # when the number of heads is not a power of 2, we use this workaround.
598
- return (get_slopes_power_of_2(closest_power_of_2) + get_slopes(
599
- 2 * closest_power_of_2)[0::2][:n - closest_power_of_2])
 
 
600
 
601
  # h, 1, 1
602
  slopes = torch.tensor(get_slopes(n_attention_heads)).reshape(
603
- n_attention_heads, 1, 1)
 
604
 
605
  return slopes
606
 
@@ -613,26 +612,26 @@ class TransnormerModel(TransnormerPreTrainedModel):
613
  def set_input_embeddings(self, value):
614
  self.embed_tokens = value
615
 
616
- def _prepare_decoder_linear_attn_mask(self, input_shape, inputs_embeds,
617
- past_key_values_length):
 
618
  bsz, tgt_len = input_shape
619
  src_len = tgt_len + past_key_values_length
620
 
621
  def power_log(x):
622
- return 2**(math.ceil(math.log(x, 2)))
623
 
624
  n = power_log(max(tgt_len, src_len))
625
  if self._linear_attn_mask.shape[-1] < n:
626
 
627
  def get_mask(n):
628
- mask = torch.triu(
629
- torch.zeros(n, n).float().fill_(float("-inf")), 1)
630
  # no slope version
631
  # -n, ..., -2, -1, 0
632
  for i in range(n):
633
  x = torch.arange(i + 1)
634
  y = x
635
- mask[i, :i + 1] = -torch.flip(y, [0])
636
 
637
  return mask
638
 
@@ -644,8 +643,7 @@ class TransnormerModel(TransnormerPreTrainedModel):
644
  linear_attn_mask = self._linear_attn_mask[:, -tgt_len:, -src_len:]
645
  num_heads = linear_attn_mask.shape[0]
646
 
647
- return linear_attn_mask[None, :, :, :].expand(bsz, num_heads, tgt_len,
648
- src_len)
649
 
650
  @add_start_docstrings_to_model_forward(TRANSNORMER_INPUTS_DOCSTRING)
651
  def forward(
@@ -659,15 +657,21 @@ class TransnormerModel(TransnormerPreTrainedModel):
659
  output_hidden_states: Optional[bool] = None,
660
  return_dict: Optional[bool] = None,
661
  ) -> Union[Tuple, BaseModelOutputWithPast]:
662
- output_attentions = (output_attentions if output_attentions is not None
663
- else self.config.output_attentions)
664
- output_hidden_states = (output_hidden_states
665
- if output_hidden_states is not None else
666
- self.config.output_hidden_states)
 
 
 
 
 
667
  use_cache = use_cache if use_cache is not None else self.config.use_cache
668
 
669
- return_dict = (return_dict if return_dict is not None else
670
- self.config.use_return_dict)
 
671
 
672
  # retrieve input_ids and inputs_embeds
673
  if input_ids is not None and inputs_embeds is not None:
@@ -689,7 +693,7 @@ class TransnormerModel(TransnormerPreTrainedModel):
689
  if past_key_values is not None:
690
  past_key_values_length = past_key_values[0][0].shape[-2]
691
  seq_length_with_past = seq_length_with_past + past_key_values_length
692
-
693
  if inputs_embeds is None:
694
  # !!! use embed_scale
695
  inputs_embeds = self.embed_scale * self.embed_tokens(input_ids)
@@ -711,72 +715,54 @@ class TransnormerModel(TransnormerPreTrainedModel):
711
  ##### norm linear layers
712
  linear_attn_padding_mask = attn_padding_mask
713
  linear_attn_mask = self._prepare_decoder_linear_attn_mask(
714
- (batch_size, seq_length), inputs_embeds, past_key_values_length)
 
715
 
716
- slope_rates = [
717
- self.slopes.to(input_ids.device) for _ in range(self.num_layers)
718
- ]
719
 
720
  for idx, layer in enumerate(self.layers):
721
  if output_hidden_states:
722
- all_hidden_states += (hidden_states, )
723
 
724
- past_key_value = (past_key_values[idx]
725
- if past_key_values is not None else None)
 
726
 
727
  slope_rate = slope_rates[idx]
728
  slope_rate = slope_rate * (1 - idx / (self.num_layers - 1) + 1e-5)
729
  mask = linear_attn_mask
730
-
731
- if self.gradient_checkpointing and self.training:
732
-
733
- def create_custom_forward(module):
734
-
735
- def custom_forward(*inputs):
736
- # None for past_key_value
737
- return module(*inputs, output_attentions, None)
738
-
739
- return custom_forward
740
-
741
- layer_outputs = torch.utils.checkpoint.checkpoint(
742
- create_custom_forward(layer),
743
- hidden_states,
744
- mask,
745
- linear_attn_padding_mask,
746
- None,
747
- )
748
- else:
749
- layer_outputs = layer(
750
- hidden_states,
751
- attn_mask=mask,
752
- attn_padding_mask=linear_attn_padding_mask,
753
- past_key_value=past_key_value,
754
- output_attentions=output_attentions,
755
- use_cache=use_cache,
756
- slope_rate=slope_rate,
757
- )
758
 
759
  hidden_states = layer_outputs[0]
760
 
761
  if use_cache:
762
- next_decoder_cache += (
763
- layer_outputs[2 if output_attentions else 1], )
764
 
765
  if output_attentions:
766
- all_self_attns += (layer_outputs[1], )
767
 
768
  hidden_states = self.final_norm(hidden_states)
769
 
770
  # add hidden states from the last decoder layer
771
  if output_hidden_states:
772
- all_hidden_states += (hidden_states, )
773
 
774
  next_cache = next_decoder_cache if use_cache else None
775
  if not return_dict:
776
  return tuple(
777
- v for v in
778
- [hidden_states, next_cache, all_hidden_states, all_self_attns]
779
- if v is not None)
 
780
  return BaseModelOutputWithPast(
781
  last_hidden_state=hidden_states,
782
  past_key_values=next_cache,
@@ -786,7 +772,6 @@ class TransnormerModel(TransnormerPreTrainedModel):
786
 
787
 
788
  class TransnormerForCausalLM(TransnormerPreTrainedModel):
789
-
790
  def __init__(self, config):
791
  super().__init__(config)
792
  self.model = TransnormerModel(config)
@@ -794,9 +779,9 @@ class TransnormerForCausalLM(TransnormerPreTrainedModel):
794
  logging_info(self.model)
795
 
796
  # the lm_head weight is automatically tied to the embed tokens weight
797
- self.lm_head = nn.Linear(config.decoder_embed_dim,
798
- config.vocab_size,
799
- bias=False)
800
 
801
  # Initialize weights and apply final processing
802
  self.post_init()
@@ -820,8 +805,9 @@ class TransnormerForCausalLM(TransnormerPreTrainedModel):
820
  return self.model
821
 
822
  @add_start_docstrings_to_model_forward(TRANSNORMER_INPUTS_DOCSTRING)
823
- @replace_return_docstrings(output_type=CausalLMOutputWithPast,
824
- config_class=_CONFIG_FOR_DOC)
 
825
  def forward(
826
  self,
827
  input_ids: torch.LongTensor = None,
@@ -859,13 +845,19 @@ class TransnormerForCausalLM(TransnormerPreTrainedModel):
859
  >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
860
  "Hey, are you consciours? Can you talk to me?\nI'm not consciours, but I can talk to you."
861
  ```"""
862
- output_attentions = (output_attentions if output_attentions is not None
863
- else self.config.output_attentions)
864
- output_hidden_states = (output_hidden_states
865
- if output_hidden_states is not None else
866
- self.config.output_hidden_states)
867
- return_dict = (return_dict if return_dict is not None else
868
- self.config.use_return_dict)
 
 
 
 
 
 
869
 
870
  # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
871
  outputs = self.model(
@@ -896,8 +888,8 @@ class TransnormerForCausalLM(TransnormerPreTrainedModel):
896
  loss = loss_fct(shift_logits, shift_labels)
897
 
898
  if not return_dict:
899
- output = (logits, ) + outputs[1:]
900
- return (loss, ) + output if loss is not None else output
901
 
902
  return CausalLMOutputWithPast(
903
  loss=loss,
@@ -924,149 +916,23 @@ class TransnormerForCausalLM(TransnormerPreTrainedModel):
924
  else:
925
  model_inputs = {"input_ids": input_ids}
926
 
927
- model_inputs.update({
928
- "past_key_values": past_key_values,
929
- "use_cache": kwargs.get("use_cache"),
930
- "attention_mask": attention_mask,
931
- })
 
 
932
  return model_inputs
933
 
934
  @staticmethod
935
  def _reorder_cache(past_key_values, beam_idx):
936
  reordered_past = ()
937
  for layer_past in past_key_values:
938
- reordered_past += (tuple(
939
- past_state.index_select(0, beam_idx)
940
- for past_state in layer_past), )
941
- return reordered_past
942
-
943
-
944
- @add_start_docstrings(
945
- """
946
- The LLaMa Model transformer with a sequence classification head on top (linear layer).
947
-
948
- [`TransnormerForSequenceClassification`] uses the last token in order to do the classification, as other causal models
949
- (e.g. GPT-2) do.
950
-
951
- Since it does classification on the last token, it requires to know the position of the last token. If a
952
- `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
953
- no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
954
- padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
955
- each row of the batch).
956
- """,
957
- TRANSNORMER_START_DOCSTRING,
958
- )
959
- class TransnormerForSequenceClassification(TransnormerPreTrainedModel):
960
- _keys_to_ignore_on_load_missing = [r"lm_head.weight"]
961
-
962
- def __init__(self, config):
963
- super().__init__(config)
964
- self.num_labels = config.num_labels
965
- self.model = TransnormerModel(config)
966
- self.score = nn.Linear(config.decoder_embed_dim,
967
- self.num_labels,
968
- bias=False)
969
-
970
- # Initialize weights and apply final processing
971
- self.post_init()
972
-
973
- def get_input_embeddings(self):
974
- return self.model.embed_tokens
975
-
976
- def set_input_embeddings(self, value):
977
- self.model.embed_tokens = value
978
-
979
- @add_start_docstrings_to_model_forward(TRANSNORMER_INPUTS_DOCSTRING)
980
- def forward(
981
- self,
982
- input_ids: torch.LongTensor = None,
983
- attn_mask: Optional[torch.Tensor] = None,
984
- past_key_values: Optional[List[torch.FloatTensor]] = None,
985
- inputs_embeds: Optional[torch.FloatTensor] = None,
986
- labels: Optional[torch.LongTensor] = None,
987
- use_cache: Optional[bool] = None,
988
- output_attentions: Optional[bool] = None,
989
- output_hidden_states: Optional[bool] = None,
990
- return_dict: Optional[bool] = None,
991
- ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
992
- r"""
993
- labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
994
- Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
995
- config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
996
- `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
997
- """
998
- return_dict = (return_dict if return_dict is not None else
999
- self.config.use_return_dict)
1000
-
1001
- transformer_outputs = self.model(
1002
- input_ids,
1003
- attn_padding_mask=attn_mask,
1004
- past_key_values=past_key_values,
1005
- inputs_embeds=inputs_embeds,
1006
- use_cache=use_cache,
1007
- output_attentions=output_attentions,
1008
- output_hidden_states=output_hidden_states,
1009
- return_dict=return_dict,
1010
- )
1011
- hidden_states = transformer_outputs[0]
1012
-
1013
- logits = self.score(hidden_states)
1014
-
1015
- if input_ids is not None:
1016
- batch_size = input_ids.shape[0]
1017
- else:
1018
- batch_size = inputs_embeds.shape[0]
1019
-
1020
- if self.config.pad_token_id is None and batch_size != 1:
1021
- raise ValueError(
1022
- "Cannot handle batch sizes > 1 if no padding token is defined."
1023
  )
1024
- if self.config.pad_token_id is None:
1025
- sequence_lengths = -1
1026
- else:
1027
- if input_ids is not None:
1028
- sequence_lengths = (
1029
- torch.ne(input_ids, self.config.pad_token_id).sum(-1) -
1030
- 1).to(logits.device)
1031
- else:
1032
- sequence_lengths = -1
1033
-
1034
- pooled_logits = logits[torch.arange(batch_size, device=logits.device),
1035
- sequence_lengths]
1036
-
1037
- loss = None
1038
- if labels is not None:
1039
- labels = labels.to(logits.device)
1040
- if self.config.problem_type is None:
1041
- if self.num_labels == 1:
1042
- self.config.problem_type = "regression"
1043
- elif self.num_labels > 1 and (labels.dtype == torch.long
1044
- or labels.dtype == torch.int):
1045
- self.config.problem_type = "single_label_classification"
1046
- else:
1047
- self.config.problem_type = "multi_label_classification"
1048
-
1049
- if self.config.problem_type == "regression":
1050
- loss_fct = MSELoss()
1051
- if self.num_labels == 1:
1052
- loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
1053
- else:
1054
- loss = loss_fct(pooled_logits, labels)
1055
- elif self.config.problem_type == "single_label_classification":
1056
- loss_fct = CrossEntropyLoss()
1057
- loss = loss_fct(pooled_logits.view(-1, self.num_labels),
1058
- labels.view(-1))
1059
- elif self.config.problem_type == "multi_label_classification":
1060
- loss_fct = BCEWithLogitsLoss()
1061
- loss = loss_fct(pooled_logits, labels)
1062
- if not return_dict:
1063
- output = (pooled_logits, ) + transformer_outputs[1:]
1064
- return ((loss, ) + output) if loss is not None else output
1065
 
1066
- return SequenceClassifierOutputWithPast(
1067
- loss=loss,
1068
- logits=pooled_logits,
1069
- past_key_values=transformer_outputs.past_key_values,
1070
- hidden_states=transformer_outputs.hidden_states,
1071
- attentions=transformer_outputs.attentions,
1072
- )
 
11
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
  # See the License for the specific language governing permissions and
13
  # limitations under the License.
14
+ # coding=utf-8
 
15
  """ PyTorch Transnormer model."""
16
  import math
17
  import os
 
28
  from transformers.modeling_outputs import (
29
  BaseModelOutputWithPast,
30
  CausalLMOutputWithPast,
 
31
  )
32
  from transformers.modeling_utils import PreTrainedModel
33
  from transformers.utils import (
 
83
  ########## start Transnormer
84
  ##### Linearized Relative Positional Encoding: https://openreview.net/forum?id=xoLyps2qWc&referrer=%5BAuthor%20Console%5D(%2Fgroup%3Fid%3DTMLR%2FAuthors%23your-submissions)
85
  class Lrpe(nn.Module):
 
86
  def __init__(
87
  self,
88
  num_heads=8,
 
92
  d = num_heads * embed_dim
93
 
94
  self.index = torch.empty(0)
95
+ self.theta = nn.Parameter(
96
+ 10000 ** (-2 / d * torch.arange(d)).reshape(num_heads, 1, -1)
97
+ )
98
 
99
  def extra_repr(self):
100
  return print_module(self)
 
113
 
114
 
115
  class GLU(nn.Module):
 
116
  def __init__(self, d1, d2, bias=False):
117
  super().__init__()
118
  if debug:
 
135
 
136
 
137
  class NormLinearAttention(nn.Module):
 
138
  def __init__(
139
  self,
140
  embed_dim,
 
190
  output_attentions,
191
  past_key_value,
192
  use_cache,
193
+ slope_rate,
194
  )
195
  # x: b n d
196
  n = x.shape[-2]
 
198
  q, k, v, u = self.qkvu_proj(x).chunk(4, dim=-1)
199
  # reshape
200
  q, k, v = map(
201
+ lambda x: rearrange(x, "b n (h d) -> b h n d", h=self.num_heads), [q, k, v]
202
+ )
203
  # act
204
  q = self.act(q)
205
  k = self.act(k)
 
207
  q_offset = 0
208
  # lrpe relys on position, get cache first
209
  if past_key_value is not None:
210
+ # reuse k, v, for evaluation only
211
  k = torch.cat([past_key_value[0], k], dim=-2)
212
  v = torch.cat([past_key_value[1], v], dim=-2)
213
  q_offset = past_key_value[0].shape[-2]
 
224
 
225
  if attn_padding_mask is not None:
226
  v = v.masked_fill(
227
+ (1 - attn_padding_mask).unsqueeze(1).unsqueeze(-1).to(torch.bool), 0
228
+ )
229
 
230
  if not has_lightning_attention:
231
  if slope_rate != None:
232
  attn_mask = torch.exp(slope_rate * attn_mask)
 
233
  output = linear_attention(q, k, v, attn_mask)
234
  else:
235
+ output = lightning_attention(
236
+ q, k, v, True, slope_rate.squeeze(-1).squeeze(-1)
237
+ )
238
 
239
  # reshape
240
  output = rearrange(output, "b h n d -> b n (h d)")
 
253
  return output, attn_weights, past_key_value
254
 
255
  def inference(
256
+ self,
257
+ x,
258
+ attn_mask: Optional[torch.Tensor] = None, # (b, h, n, m)
259
+ attn_padding_mask: Optional[torch.Tensor] = None, # (b, m)
260
+ output_attentions: bool = False,
261
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
262
+ use_cache: bool = False,
263
+ slope_rate: Optional[torch.Tensor] = None, # (h, 1, 1)
264
  ):
265
  # x: b n d
266
  n = x.shape[-2]
 
268
  q, k, v, u = self.qkvu_proj(x).chunk(4, dim=-1)
269
  # reshape
270
  q, k, v = map(
271
+ lambda x: rearrange(x, "b n (h d) -> b h n d", h=self.num_heads), [q, k, v]
272
+ )
273
  # act
274
  q = self.act(q)
275
  k = self.act(k)
 
277
  # rpe
278
  if self.linear_use_lrpe:
279
  q = self.lrpe(q, offset=self.offset)
280
+ k = self.lrpe(k)
281
 
282
  if past_key_value == None:
283
  self.offset = q.shape[-2]
 
295
 
296
  if attn_padding_mask is not None:
297
  attn_mask = attn_mask.masked_fill(
298
+ (1 - attn_padding_mask).unsqueeze(1).unsqueeze(2).to(torch.bool),
 
299
  0,
300
  )
301
  energy = torch.einsum("... n d, ... m d -> ... n m", q, k)
 
306
  output = torch.einsum("... n m, ... m d -> ... n d", energy, v)
307
 
308
  eval_and_not_generate = eval(
309
+ os.environ.get("eval_and_not_generate", default="False")
310
+ )
311
  if eval_and_not_generate:
312
  kv = None
313
  else:
314
  # b, h, n, e, d
315
+ kv_outproduct = torch.einsum("... n e, ... n d -> ... n e d", k, v)
 
316
  # 1, 1, n, 1, 1
317
+ index = torch.arange(n - 1, -1, -1).reshape(1, 1, -1, 1, 1).to(x)
 
318
  # (h, 1, 1) -> (1, h, 1, 1, 1); (1, h, 1, 1, 1), (1, 1, n, 1, 1) -> (1, h, n, 1, 1)
319
+ decay = ratio.unsqueeze(0).unsqueeze(-1) ** index
320
 
321
  kv_outproduct_with_decay = kv_outproduct * decay
322
  kv = torch.sum(kv_outproduct_with_decay, dim=-3)
 
327
  for i in range(n):
328
  kv = ratio * kv + torch.einsum(
329
  "... n d, ... n e -> ... d e",
330
+ k[:, :, i : i + 1],
331
+ v[:, :, i : i + 1],
332
+ )
333
+ qkv = torch.einsum(
334
+ "... n e, ... e d -> ... n d", q[:, :, i : i + 1], kv
335
  )
 
 
336
  output.append(qkv)
337
  output = torch.concat(output, dim=-2)
338
 
 
351
 
352
 
353
  class TransnormerDecoderLayer(nn.Module):
 
354
  def __init__(self, config: TransnormerConfig):
355
  super().__init__()
356
  self.embed_dim = config.decoder_embed_dim
 
389
  return residual + x
390
 
391
  def forward(
392
+ self,
393
+ x,
394
+ attn_mask: Optional[torch.Tensor] = None,
395
+ attn_padding_mask: Optional[torch.Tensor] = None,
396
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
397
+ output_attentions: Optional[bool] = False,
398
+ use_cache: Optional[bool] = False,
399
+ slope_rate: Optional[torch.Tensor] = None, # (h, 1, 1)
400
  ):
401
  residual = x
402
  x = self.token_norm(x)
 
416
  x = self.channel_mixer(x)
417
  x = self.residual_connection(x, residual)
418
 
419
+ outputs = (x,)
420
 
421
  if output_attentions:
422
+ outputs += (self_attn_weights,)
423
 
424
  if use_cache:
425
+ outputs += (present_key_value,)
426
 
427
  return outputs
428
 
 
444
  """
445
 
446
 
447
+ @add_start_docstrings(
448
+ TRANSNORMER_START_DOCSTRING,
449
+ )
450
  class TransnormerPreTrainedModel(PreTrainedModel):
451
  config_class = TransnormerConfig
452
  base_model_prefix = "model"
 
531
  """
532
 
533
 
534
+ @add_start_docstrings(
535
+ TRANSNORMER_START_DOCSTRING,
536
+ )
537
  class TransnormerModel(TransnormerPreTrainedModel):
538
  """
539
  Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`TransnormerDecoderLayer`]
 
557
  self.slopes = self._build_slope_tensor(config.decoder_attention_heads)
558
 
559
  # params
560
+ self.embed_tokens = nn.Embedding(
561
+ config.vocab_size, config.decoder_embed_dim, self.padding_idx
562
+ )
563
  self.layers = nn.ModuleList([])
564
  for i in range(config.decoder_layers):
565
  if len(self.linear_use_lrpe_list) > 0:
566
  config.linear_use_lrpe = self.linear_use_lrpe_list[i]
567
  self.layers.append(TransnormerDecoderLayer(config))
568
 
569
+ self.final_norm = get_norm_fn(config.norm_type)(config.decoder_embed_dim)
 
570
  self.embed_dim = config.decoder_embed_dim
571
+ self.embed_scale = (
572
+ 1.0 if config.no_scale_embedding else math.sqrt(self.embed_dim)
573
+ )
574
 
575
  # Initialize weights and apply final processing
576
  self.post_init()
577
 
578
  @staticmethod
579
  def _build_slope_tensor(n_attention_heads: int):
 
580
  def get_slopes(n):
 
581
  def get_slopes_power_of_2(n):
582
+ start = 2 ** (-(2 ** -(math.log2(n) - 3)))
583
  ratio = start
584
  return [start * ratio**i for i in range(n)]
585
 
 
588
  n
589
  ) # In the paper, we only train models that have 2^a heads for some a. This function has
590
  else: # some good properties that only occur when the input is a power of 2. To maintain that even
591
+ closest_power_of_2 = 2 ** math.floor(
592
  math.log2(n)
593
  ) # when the number of heads is not a power of 2, we use this workaround.
594
+ return (
595
+ get_slopes_power_of_2(closest_power_of_2)
596
+ + get_slopes(2 * closest_power_of_2)[0::2][: n - closest_power_of_2]
597
+ )
598
 
599
  # h, 1, 1
600
  slopes = torch.tensor(get_slopes(n_attention_heads)).reshape(
601
+ n_attention_heads, 1, 1
602
+ )
603
 
604
  return slopes
605
 
 
612
  def set_input_embeddings(self, value):
613
  self.embed_tokens = value
614
 
615
+ def _prepare_decoder_linear_attn_mask(
616
+ self, input_shape, inputs_embeds, past_key_values_length
617
+ ):
618
  bsz, tgt_len = input_shape
619
  src_len = tgt_len + past_key_values_length
620
 
621
  def power_log(x):
622
+ return 2 ** (math.ceil(math.log(x, 2)))
623
 
624
  n = power_log(max(tgt_len, src_len))
625
  if self._linear_attn_mask.shape[-1] < n:
626
 
627
  def get_mask(n):
628
+ mask = torch.triu(torch.zeros(n, n).float().fill_(float("-inf")), 1)
 
629
  # no slope version
630
  # -n, ..., -2, -1, 0
631
  for i in range(n):
632
  x = torch.arange(i + 1)
633
  y = x
634
+ mask[i, : i + 1] = -torch.flip(y, [0])
635
 
636
  return mask
637
 
 
643
  linear_attn_mask = self._linear_attn_mask[:, -tgt_len:, -src_len:]
644
  num_heads = linear_attn_mask.shape[0]
645
 
646
+ return linear_attn_mask[None, :, :, :].expand(bsz, num_heads, tgt_len, src_len)
 
647
 
648
  @add_start_docstrings_to_model_forward(TRANSNORMER_INPUTS_DOCSTRING)
649
  def forward(
 
657
  output_hidden_states: Optional[bool] = None,
658
  return_dict: Optional[bool] = None,
659
  ) -> Union[Tuple, BaseModelOutputWithPast]:
660
+ output_attentions = (
661
+ output_attentions
662
+ if output_attentions is not None
663
+ else self.config.output_attentions
664
+ )
665
+ output_hidden_states = (
666
+ output_hidden_states
667
+ if output_hidden_states is not None
668
+ else self.config.output_hidden_states
669
+ )
670
  use_cache = use_cache if use_cache is not None else self.config.use_cache
671
 
672
+ return_dict = (
673
+ return_dict if return_dict is not None else self.config.use_return_dict
674
+ )
675
 
676
  # retrieve input_ids and inputs_embeds
677
  if input_ids is not None and inputs_embeds is not None:
 
693
  if past_key_values is not None:
694
  past_key_values_length = past_key_values[0][0].shape[-2]
695
  seq_length_with_past = seq_length_with_past + past_key_values_length
696
+
697
  if inputs_embeds is None:
698
  # !!! use embed_scale
699
  inputs_embeds = self.embed_scale * self.embed_tokens(input_ids)
 
715
  ##### norm linear layers
716
  linear_attn_padding_mask = attn_padding_mask
717
  linear_attn_mask = self._prepare_decoder_linear_attn_mask(
718
+ (batch_size, seq_length), inputs_embeds, past_key_values_length
719
+ )
720
 
721
+ slope_rates = [self.slopes.to(input_ids.device) for _ in range(self.num_layers)]
 
 
722
 
723
  for idx, layer in enumerate(self.layers):
724
  if output_hidden_states:
725
+ all_hidden_states += (hidden_states,)
726
 
727
+ past_key_value = (
728
+ past_key_values[idx] if past_key_values is not None else None
729
+ )
730
 
731
  slope_rate = slope_rates[idx]
732
  slope_rate = slope_rate * (1 - idx / (self.num_layers - 1) + 1e-5)
733
  mask = linear_attn_mask
734
+
735
+ layer_outputs = layer(
736
+ hidden_states,
737
+ attn_mask=mask,
738
+ attn_padding_mask=linear_attn_padding_mask,
739
+ past_key_value=past_key_value,
740
+ output_attentions=output_attentions,
741
+ use_cache=use_cache,
742
+ slope_rate=slope_rate,
743
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
744
 
745
  hidden_states = layer_outputs[0]
746
 
747
  if use_cache:
748
+ next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
 
749
 
750
  if output_attentions:
751
+ all_self_attns += (layer_outputs[1],)
752
 
753
  hidden_states = self.final_norm(hidden_states)
754
 
755
  # add hidden states from the last decoder layer
756
  if output_hidden_states:
757
+ all_hidden_states += (hidden_states,)
758
 
759
  next_cache = next_decoder_cache if use_cache else None
760
  if not return_dict:
761
  return tuple(
762
+ v
763
+ for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
764
+ if v is not None
765
+ )
766
  return BaseModelOutputWithPast(
767
  last_hidden_state=hidden_states,
768
  past_key_values=next_cache,
 
772
 
773
 
774
  class TransnormerForCausalLM(TransnormerPreTrainedModel):
 
775
  def __init__(self, config):
776
  super().__init__(config)
777
  self.model = TransnormerModel(config)
 
779
  logging_info(self.model)
780
 
781
  # the lm_head weight is automatically tied to the embed tokens weight
782
+ self.lm_head = nn.Linear(
783
+ config.decoder_embed_dim, config.vocab_size, bias=False
784
+ )
785
 
786
  # Initialize weights and apply final processing
787
  self.post_init()
 
805
  return self.model
806
 
807
  @add_start_docstrings_to_model_forward(TRANSNORMER_INPUTS_DOCSTRING)
808
+ @replace_return_docstrings(
809
+ output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
810
+ )
811
  def forward(
812
  self,
813
  input_ids: torch.LongTensor = None,
 
845
  >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
846
  "Hey, are you consciours? Can you talk to me?\nI'm not consciours, but I can talk to you."
847
  ```"""
848
+ output_attentions = (
849
+ output_attentions
850
+ if output_attentions is not None
851
+ else self.config.output_attentions
852
+ )
853
+ output_hidden_states = (
854
+ output_hidden_states
855
+ if output_hidden_states is not None
856
+ else self.config.output_hidden_states
857
+ )
858
+ return_dict = (
859
+ return_dict if return_dict is not None else self.config.use_return_dict
860
+ )
861
 
862
  # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
863
  outputs = self.model(
 
888
  loss = loss_fct(shift_logits, shift_labels)
889
 
890
  if not return_dict:
891
+ output = (logits,) + outputs[1:]
892
+ return (loss,) + output if loss is not None else output
893
 
894
  return CausalLMOutputWithPast(
895
  loss=loss,
 
916
  else:
917
  model_inputs = {"input_ids": input_ids}
918
 
919
+ model_inputs.update(
920
+ {
921
+ "past_key_values": past_key_values,
922
+ "use_cache": kwargs.get("use_cache"),
923
+ "attention_mask": attention_mask,
924
+ }
925
+ )
926
  return model_inputs
927
 
928
  @staticmethod
929
  def _reorder_cache(past_key_values, beam_idx):
930
  reordered_past = ()
931
  for layer_past in past_key_values:
932
+ reordered_past += (
933
+ tuple(
934
+ past_state.index_select(0, beam_idx) for past_state in layer_past
935
+ ),
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
936
  )
937
+ return reordered_past
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
938
 
 
 
 
 
 
 
 
tokenization_baichuan.py CHANGED
@@ -73,6 +73,11 @@ class BaiChuanTokenizer(PreTrainedTokenizer):
73
  if isinstance(unk_token, str) else unk_token)
74
  pad_token = (AddedToken(pad_token, lstrip=False, rstrip=False)
75
  if isinstance(pad_token, str) else pad_token)
 
 
 
 
 
76
  super().__init__(
77
  bos_token=bos_token,
78
  eos_token=eos_token,
@@ -84,11 +89,6 @@ class BaiChuanTokenizer(PreTrainedTokenizer):
84
  clean_up_tokenization_spaces=clean_up_tokenization_spaces,
85
  **kwargs,
86
  )
87
- self.vocab_file = vocab_file
88
- self.add_bos_token = add_bos_token
89
- self.add_eos_token = add_eos_token
90
- self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
91
- self.sp_model.Load(vocab_file)
92
 
93
  def __getstate__(self):
94
  state = self.__dict__.copy()
 
73
  if isinstance(unk_token, str) else unk_token)
74
  pad_token = (AddedToken(pad_token, lstrip=False, rstrip=False)
75
  if isinstance(pad_token, str) else pad_token)
76
+ self.vocab_file = vocab_file
77
+ self.add_bos_token = add_bos_token
78
+ self.add_eos_token = add_eos_token
79
+ self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
80
+ self.sp_model.Load(vocab_file)
81
  super().__init__(
82
  bos_token=bos_token,
83
  eos_token=eos_token,
 
89
  clean_up_tokenization_spaces=clean_up_tokenization_spaces,
90
  **kwargs,
91
  )
 
 
 
 
 
92
 
93
  def __getstate__(self):
94
  state = self.__dict__.copy()