fix flash_attn import logic to comply with huggingface import check
Browse files- modeling_intern_vit.py +9 -11
modeling_intern_vit.py
CHANGED
@@ -19,20 +19,18 @@ from transformers.utils import logging
|
|
19 |
|
20 |
from .configuration_intern_vit import InternVisionConfig
|
21 |
|
|
|
22 |
try:
|
23 |
-
|
24 |
-
from flash_attn.flash_attn_interface import \
|
25 |
-
flash_attn_unpadded_qkvpacked_func
|
26 |
-
except: # v2
|
27 |
-
from flash_attn.flash_attn_interface import \
|
28 |
-
flash_attn_varlen_qkvpacked_func as flash_attn_unpadded_qkvpacked_func
|
29 |
-
|
30 |
from flash_attn.bert_padding import pad_input, unpad_input
|
31 |
-
|
32 |
has_flash_attn = True
|
33 |
-
except:
|
34 |
-
|
35 |
-
|
|
|
|
|
|
|
|
|
36 |
|
37 |
logger = logging.get_logger(__name__)
|
38 |
|
|
|
19 |
|
20 |
from .configuration_intern_vit import InternVisionConfig
|
21 |
|
22 |
+
has_flash_attn = False
|
23 |
try:
|
24 |
+
from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func
|
|
|
|
|
|
|
|
|
|
|
|
|
25 |
from flash_attn.bert_padding import pad_input, unpad_input
|
|
|
26 |
has_flash_attn = True
|
27 |
+
except ImportError:
|
28 |
+
try:
|
29 |
+
from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func as flash_attn_unpadded_qkvpacked_func
|
30 |
+
from flash_attn.bert_padding import pad_input, unpad_input
|
31 |
+
has_flash_attn = True
|
32 |
+
except ImportError:
|
33 |
+
print('FlashAttention is not installed.')
|
34 |
|
35 |
logger = logging.get_logger(__name__)
|
36 |
|