ccdv committed
Commit e1c052c
1 Parent(s): c870da8
Files changed (2)
  1. README.md +2 -2
  2. modeling_lsg_bart.py +11 -5
README.md CHANGED
@@ -18,7 +18,7 @@ model-index:
 <!-- This model card has been generated automatically according to the information the Trainer had access to. You
 should probably proofread and complete it, then remove this comment. -->
 
-**Transformers >= 4.35.2**\
+**Transformers >= 4.36.1**\
 **This model relies on a custom modeling file, you need to add trust_remote_code=True**\
 **See [\#13467](https://github.com/huggingface/transformers/pull/13467)**
 
@@ -105,7 +105,7 @@ The following hyperparameters were used during generation:
 
 ### Framework versions
 
-- Transformers 4.35.2
+- Transformers 4.36.1
 - Pytorch 1.12.1
 - Datasets 2.3.2
 - Tokenizers 0.11.6
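The trust_remote_code note above is what users hit in practice: loading the checkpoint executes the repo's custom modeling_lsg_bart.py. A minimal loading sketch, with the repo id as a hypothetical placeholder since the commit page doesn't name the checkpoint:

```python
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

repo_id = "ccdv/lsg-bart-base-4096"  # hypothetical placeholder repo id

tokenizer = AutoTokenizer.from_pretrained(repo_id)
# trust_remote_code=True lets Transformers execute the repo's custom
# modeling_lsg_bart.py instead of the built-in BART classes.
model = AutoModelForSeq2SeqLM.from_pretrained(repo_id, trust_remote_code=True)
```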
modeling_lsg_bart.py CHANGED
@@ -828,17 +828,17 @@ class LSGBartEncoder(LSGBartPretrainedModel, BartEncoder):
         if input_ids is not None and inputs_embeds is not None:
             raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
         elif input_ids is not None:
-            input_shape = input_ids.size()
-            input_ids = input_ids.view(-1, input_shape[-1])
+            input = input_ids
+            input_ids = input_ids.view(-1, input_ids.shape[-1])
         elif inputs_embeds is not None:
-            input_shape = inputs_embeds.size()[:-1]
+            input = inputs_embeds[:, :, -1]
         else:
             raise ValueError("You have to specify either input_ids or inputs_embeds")
 
         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
 
-        embed_pos = self.embed_positions(inputs_embeds)
+        embed_pos = self.embed_positions(input).to(inputs_embeds.device)
         hidden_states = inputs_embeds + embed_pos
 
         # Add global tokens
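This hunk tracks upstream BartEncoder.forward: embed_positions now receives input_ids directly, or the (batch, seq)-shaped slice inputs_embeds[:, :, -1] when only embeddings are given, because BART's learned positional embedding reads nothing but the tensor's shape. A minimal sketch of that shape-only behavior, mirroring upstream BartLearnedPositionalEmbedding:

```python
import torch
import torch.nn as nn

class LearnedPositionalEmbedding(nn.Embedding):
    """Position ids are derived from the input's shape, never its values."""

    def __init__(self, num_positions: int, embedding_dim: int):
        self.offset = 2  # BART reserves two extra embedding slots
        super().__init__(num_positions + self.offset, embedding_dim)

    def forward(self, input: torch.Tensor, past_key_values_length: int = 0):
        # Only input.shape[:2] is read, which is why a float slice of
        # inputs_embeds works as a stand-in for input_ids.
        bsz, seq_len = input.shape[:2]
        # Upstream allocates positions on the embedding's own device,
        # hence the .to(inputs_embeds.device) in the hunk above.
        positions = torch.arange(
            past_key_values_length, past_key_values_length + seq_len,
            dtype=torch.long, device=self.weight.device,
        ).expand(bsz, -1)
        return super().forward(positions + self.offset)

pos = LearnedPositionalEmbedding(1024, 768)
ids = torch.randint(0, 50265, (2, 16))
embeds = torch.randn(2, 16, 768)
assert torch.equal(pos(ids), pos(embeds[:, :, -1]))  # same positions either way
```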
@@ -931,6 +931,12 @@ class LSGBartModel(LSGBartPretrainedModel, BartModel):
         self.encoder = LSGBartEncoder(config, self.shared)
         self.decoder = BartDecoder(config, self.shared)
 
+        self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
+        if self._use_flash_attention_2:
+            logger.warning(
+                "[WARNING flash-attention]: LSG doesn't support flash-attention currently"
+            )
+
         # Initialize weights and apply final processing
         self.post_init()
 
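The guard keys off config._attn_implementation, which Transformers 4.36 sets when the caller picks an attention backend; since LSG's sparse attention has no flash-attention kernel, the model warns up front instead of failing inside the custom attention code. A hedged sketch of how the warning would be triggered, reusing the hypothetical repo id from above:

```python
from transformers import AutoModelForSeq2SeqLM

# attn_implementation is a real from_pretrained kwarg introduced in
# Transformers 4.36; the repo id is still a hypothetical placeholder.
model = AutoModelForSeq2SeqLM.from_pretrained(
    "ccdv/lsg-bart-base-4096",
    trust_remote_code=True,
    attn_implementation="flash_attention_2",
)
# Expected log from the new guard:
# [WARNING flash-attention]: LSG doesn't support flash-attention currently
```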
@@ -1093,4 +1099,4 @@ try:
         str_to_class(value.split(".")[-1]).register_for_auto_class(key)
 except:
     warn("AutoRegister isn't available, you'll have to manually copy modeling.py after .save_pretrained(...).")
-    warn("Update to transformers >= 4.35.2 to fix.")
+    warn("Update to transformers >= 4.36.1 to fix.")
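The last hunk only bumps the version quoted in the fallback warning. For context, the surrounding try block calls register_for_auto_class, which records an auto_map entry so that save_pretrained also copies the custom modeling file. A minimal sketch of that registration pattern, with a hypothetical config class:

```python
from transformers import PretrainedConfig

class LSGBartConfig(PretrainedConfig):  # hypothetical stand-in class
    model_type = "lsg_bart"

# register_for_auto_class is the real Transformers API used in the try block:
# once registered, save_pretrained(...) copies the defining .py file too,
# sparing users the manual copy the warning message describes.
LSGBartConfig.register_for_auto_class("AutoConfig")
```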