small fix
Browse files- README.md +2 -2
- modeling_lsg_bart.py +11 -5
README.md
CHANGED
@@ -18,7 +18,7 @@ model-index:
|
|
18 |
<!-- This model card has been generated automatically according to the information the Trainer had access to. You
|
19 |
should probably proofread and complete it, then remove this comment. -->
|
20 |
|
21 |
-
**Transformers >= 4.
|
22 |
**This model relies on a custom modeling file, you need to add trust_remote_code=True**\
|
23 |
**See [\#13467](https://github.com/huggingface/transformers/pull/13467)**
|
24 |
|
@@ -105,7 +105,7 @@ The following hyperparameters were used during generation:
|
|
105 |
|
106 |
### Framework versions
|
107 |
|
108 |
-
- Transformers 4.
|
109 |
- Pytorch 1.12.1
|
110 |
- Datasets 2.3.2
|
111 |
- Tokenizers 0.11.6
|
|
|
18 |
<!-- This model card has been generated automatically according to the information the Trainer had access to. You
|
19 |
should probably proofread and complete it, then remove this comment. -->
|
20 |
|
21 |
+
**Transformers >= 4.36.1**\
|
22 |
**This model relies on a custom modeling file, you need to add trust_remote_code=True**\
|
23 |
**See [\#13467](https://github.com/huggingface/transformers/pull/13467)**
|
24 |
|
|
|
105 |
|
106 |
### Framework versions
|
107 |
|
108 |
+
- Transformers 4.36.1
|
109 |
- Pytorch 1.12.1
|
110 |
- Datasets 2.3.2
|
111 |
- Tokenizers 0.11.6
|
modeling_lsg_bart.py
CHANGED
@@ -828,17 +828,17 @@ class LSGBartEncoder(LSGBartPretrainedModel, BartEncoder):
|
|
828 |
if input_ids is not None and inputs_embeds is not None:
|
829 |
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
830 |
elif input_ids is not None:
|
831 |
-
|
832 |
-
input_ids = input_ids.view(-1,
|
833 |
elif inputs_embeds is not None:
|
834 |
-
|
835 |
else:
|
836 |
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
837 |
|
838 |
if inputs_embeds is None:
|
839 |
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
|
840 |
|
841 |
-
embed_pos = self.embed_positions(inputs_embeds)
|
842 |
hidden_states = inputs_embeds + embed_pos
|
843 |
|
844 |
# Add global tokens
|
@@ -931,6 +931,12 @@ class LSGBartModel(LSGBartPretrainedModel, BartModel):
|
|
931 |
self.encoder = LSGBartEncoder(config, self.shared)
|
932 |
self.decoder = BartDecoder(config, self.shared)
|
933 |
|
|
|
|
|
|
|
|
|
|
|
|
|
934 |
# Initialize weights and apply final processing
|
935 |
self.post_init()
|
936 |
|
@@ -1093,4 +1099,4 @@ try:
|
|
1093 |
str_to_class(value.split(".")[-1]).register_for_auto_class(key)
|
1094 |
except:
|
1095 |
warn("AutoRegister isn't available, you'll have to manually copy modeling.py after .save_pretrained(...).")
|
1096 |
-
warn("Update to transformers >= 4.
|
|
|
828 |
if input_ids is not None and inputs_embeds is not None:
|
829 |
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
830 |
elif input_ids is not None:
|
831 |
+
input = input_ids
|
832 |
+
input_ids = input_ids.view(-1, input_ids.shape[-1])
|
833 |
elif inputs_embeds is not None:
|
834 |
+
input = inputs_embeds[:, :, -1]
|
835 |
else:
|
836 |
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
837 |
|
838 |
if inputs_embeds is None:
|
839 |
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
|
840 |
|
841 |
+
embed_pos = self.embed_positions(input).to(inputs_embeds.device)
|
842 |
hidden_states = inputs_embeds + embed_pos
|
843 |
|
844 |
# Add global tokens
|
|
|
931 |
self.encoder = LSGBartEncoder(config, self.shared)
|
932 |
self.decoder = BartDecoder(config, self.shared)
|
933 |
|
934 |
+
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
|
935 |
+
if self._use_flash_attention_2:
|
936 |
+
logger.warning(
|
937 |
+
"[WARNING flash-attention]: LSG doesnt support flash-attention currently"
|
938 |
+
)
|
939 |
+
|
940 |
# Initialize weights and apply final processing
|
941 |
self.post_init()
|
942 |
|
|
|
1099 |
str_to_class(value.split(".")[-1]).register_for_auto_class(key)
|
1100 |
except:
|
1101 |
warn("AutoRegister isn't available, you'll have to manually copy modeling.py after .save_pretrained(...).")
|
1102 |
+
warn("Update to transformers >= 4.36.1 to fix.")
|