smg478 committed
Commit c2ced98
Parent(s): 7383c55

fix flash_attn import logic to comply with huggingface import check

Files changed (1): modeling_intern_vit.py +9 -11
modeling_intern_vit.py CHANGED
@@ -19,20 +19,18 @@ from transformers.utils import logging
 
 from .configuration_intern_vit import InternVisionConfig
 
+has_flash_attn = False
 try:
-    try:  # v1
-        from flash_attn.flash_attn_interface import \
-            flash_attn_unpadded_qkvpacked_func
-    except:  # v2
-        from flash_attn.flash_attn_interface import \
-            flash_attn_varlen_qkvpacked_func as flash_attn_unpadded_qkvpacked_func
-
+    from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func
     from flash_attn.bert_padding import pad_input, unpad_input
-
     has_flash_attn = True
-except:
-    print('FlashAttention is not installed.')
-    has_flash_attn = False
+except ImportError:
+    try:
+        from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func as flash_attn_unpadded_qkvpacked_func
+        from flash_attn.bert_padding import pad_input, unpad_input
+        has_flash_attn = True
+    except ImportError:
+        print('FlashAttention is not installed.')
 
 logger = logging.get_logger(__name__)
 
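For readers comparing the two guards: the old code nested a bare try/except inside the outer try, while the new code uses flat imports and catches only ImportError, which the commit message says is needed to comply with Hugging Face's import check for remote code. A minimal standalone sketch of the new pattern follows; it runs whether or not flash-attn is installed, and the trailing print is illustrative, not part of the original file:

has_flash_attn = False
try:
    # flash-attn v1 exposes the kernel under this name.
    from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func
    from flash_attn.bert_padding import pad_input, unpad_input
    has_flash_attn = True
except ImportError:
    try:
        # flash-attn v2 renamed the function; alias it back to the v1 name
        # so the rest of the module stays version-agnostic.
        from flash_attn.flash_attn_interface import (
            flash_attn_varlen_qkvpacked_func as flash_attn_unpadded_qkvpacked_func,
        )
        from flash_attn.bert_padding import pad_input, unpad_input
        has_flash_attn = True
    except ImportError:
        print('FlashAttention is not installed.')

print(f'FlashAttention available: {has_flash_attn}')  # illustrative check, not in the original file

Because only ImportError is caught, unrelated failures raised while importing flash_attn still surface instead of being silently swallowed by the old bare except, and each branch sets has_flash_attn to True only after all of its imports have succeeded.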