lyogavin committed
Commit 1083953
1 Parent(s): fc4424c

fix default value of xentropy

Files changed (1)
  1. modeling_flash_llama.py +11 -3
modeling_flash_llama.py CHANGED
@@ -44,11 +44,19 @@ try:
     from flash_attn.bert_padding import unpad_input, pad_input
     flash_attn_v2_installed = True
     print('>>>> Flash Attention installed')
-    from flash_attn.losses.cross_entropy import CrossEntropyLoss as xCrossEntropyLoss
 except ImportError:
     flash_attn_v2_installed = False
     raise ImportError('Please install Flash Attention: `pip install flash-attn --no-build-isolation`')
 
+try:
+    from flash_attn.losses.cross_entropy import CrossEntropyLoss as xCrossEntropyLoss
+    flash_xentropy_installed = True
+    print('>>>> xentropy installed')
+except ImportError:
+    flash_xentropy_installed = False
+    raise ImportError('Please install xentropy kernels: `pip install git+https://github.com/HazyResearch/flash-attention.git#subdirectory=csrc/xentropy`')
+
+
 try:
     from flash_attn.layers.rotary import apply_rotary_emb_func
     flash_rope_installed = True
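The new guard mirrors the existing flash-attn one: a `flash_xentropy_installed` flag plus a hard ImportError. Note that because both except branches re-raise, the flags are effectively always True whenever the module imports at all. A minimal sketch of how such a flag would gate a fallback if the raise were ever dropped (`pick_loss_fct` is a hypothetical helper; the actual selection logic is not part of this diff):

    import torch.nn as nn

    def pick_loss_fct(xentropy: bool) -> nn.Module:
        # Use the fused flash-attn kernel only when requested and importable;
        # otherwise fall back to the stock PyTorch implementation.
        if xentropy and flash_xentropy_installed:
            return xCrossEntropyLoss()
        return nn.CrossEntropyLoss()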
@@ -774,7 +782,7 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         only_last_logit: Optional[bool] = None,
-        xentropy: Optional[bool] = None,
+        xentropy: Optional[bool] = False,
         is_padded_inputs: Optional[bool] = None,
     ) -> Union[Tuple, CausalLMOutputWithPast]:
         r"""
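With the default now a plain False instead of None, truthiness checks on `xentropy` inside forward behave the same, but the opt-in is explicit in the signature. A hedged usage sketch (`model`, `input_ids`, and `labels` are stand-ins; `labels` is assumed from the standard transformers causal-LM interface, not shown in this hunk):

    # Opt in to the fused cross-entropy for a training forward pass.
    outputs = model(input_ids=input_ids, labels=labels, xentropy=True)
    loss = outputs.loss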
@@ -869,7 +877,7 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
 
     def prepare_inputs_for_generation(
         self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, only_last_logit=False,
-        xentropy=False, **kwargs
+        xentropy=True, **kwargs
     ):
         if past_key_values:
             input_ids = input_ids[:, -1:]
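On the generation side, prepare_inputs_for_generation now seeds xentropy=True by default. Assuming the kwarg rides along through generate()'s usual model-kwargs plumbing (standard transformers behavior, not shown in this diff), overriding it would look like:

    # Default after this commit: fused kernel enabled during generation.
    out = model.generate(input_ids, max_new_tokens=32)

    # Hypothetical explicit opt-out via forwarded model kwargs.
    out = model.generate(input_ids, max_new_tokens=32, xentropy=False)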
 