fix default value of xentropy
modeling_flash_llama.py  CHANGED  (+11 -3)
@@ -44,11 +44,19 @@ try:
     from flash_attn.bert_padding import unpad_input, pad_input
     flash_attn_v2_installed = True
     print('>>>> Flash Attention installed')
-    from flash_attn.losses.cross_entropy import CrossEntropyLoss as xCrossEntropyLoss
 except ImportError:
     flash_attn_v2_installed = False
     raise ImportError('Please install Flash Attention: `pip install flash-attn --no-build-isolation`')
 
+try:
+    from flash_attn.losses.cross_entropy import CrossEntropyLoss as xCrossEntropyLoss
+    flash_xentropy_installed = True
+    print('>>>> xentropy installed')
+except ImportError:
+    flash_xentropy_installed = False
+    raise ImportError('Please install xentropy kernels: `pip install git+https://github.com/HazyResearch/flash-attention.git#subdirectory=csrc/xentropy`')
+
+
 try:
     from flash_attn.layers.rotary import apply_rotary_emb_func
     flash_rope_installed = True
@@ -774,7 +782,7 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         only_last_logit: Optional[bool] = None,
-        xentropy: Optional[bool] =
+        xentropy: Optional[bool] = False,
         is_padded_inputs: Optional[bool] = None,
     ) -> Union[Tuple, CausalLMOutputWithPast]:
         r"""
@@ -869,7 +877,7 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
 
     def prepare_inputs_for_generation(
         self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, only_last_logit=False,
-        xentropy=
+        xentropy=True, **kwargs
     ):
        if past_key_values:
            input_ids = input_ids[:, -1:]
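For context, here is a minimal sketch of how a flag like `xentropy` can switch the causal-LM loss between the stock PyTorch cross entropy and flash-attn's fused kernel. This is not the file's actual forward() body: compute_lm_loss is a hypothetical helper, the quiet fallback to torch.nn.CrossEntropyLoss is an assumption (the patched module raises an ImportError instead), and the shift-by-one label handling follows the usual Hugging Face causal-LM pattern.

import torch

try:
    # Fused CUDA cross-entropy from flash-attn (needs the xentropy extension and a GPU).
    from flash_attn.losses.cross_entropy import CrossEntropyLoss as xCrossEntropyLoss
    flash_xentropy_installed = True
except ImportError:
    # Assumption for this sketch: fall back quietly; the patched module raises instead.
    flash_xentropy_installed = False


def compute_lm_loss(logits: torch.Tensor, labels: torch.Tensor,
                    xentropy: bool = False) -> torch.Tensor:
    """Hypothetical helper: shift-by-one causal-LM loss, fused kernel if requested."""
    shift_logits = logits[..., :-1, :].contiguous()   # predict token t+1 from position t
    shift_labels = labels[..., 1:].contiguous()
    if xentropy and flash_xentropy_installed:
        loss_fct = xCrossEntropyLoss()                # fused kernel (fp16/bf16 logits, CUDA only)
    else:
        loss_fct = torch.nn.CrossEntropyLoss()        # stock PyTorch loss
    return loss_fct(shift_logits.view(-1, shift_logits.size(-1)),
                    shift_labels.view(-1))

As the diff itself shows, the new defaults leave `xentropy` off in forward() (False) and on in prepare_inputs_for_generation() (True).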