feat: added back option not to use flash attention
Browse files- configuration_bert.py +2 -0
- modeling_bert.py +4 -2
configuration_bert.py
CHANGED
@@ -81,6 +81,7 @@ class JinaBertConfig(PretrainedConfig):
|
|
81 |
fused_dropout_add_ln=False,
|
82 |
fused_bias_fc=False,
|
83 |
pad_vocab_size_multiple=1,
|
|
|
84 |
**kwargs,
|
85 |
):
|
86 |
assert 'position_embedding_type' not in kwargs
|
@@ -106,3 +107,4 @@ class JinaBertConfig(PretrainedConfig):
|
|
106 |
self.fused_dropout_add_ln = fused_dropout_add_ln
|
107 |
self.fused_bias_fc = fused_bias_fc
|
108 |
self.pad_vocab_size_multiple = pad_vocab_size_multiple
|
|
|
|
81 |
fused_dropout_add_ln=False,
|
82 |
fused_bias_fc=False,
|
83 |
pad_vocab_size_multiple=1,
|
84 |
+
use_flash_attn=True,
|
85 |
**kwargs,
|
86 |
):
|
87 |
assert 'position_embedding_type' not in kwargs
|
|
|
107 |
self.fused_dropout_add_ln = fused_dropout_add_ln
|
108 |
self.fused_bias_fc = fused_bias_fc
|
109 |
self.pad_vocab_size_multiple = pad_vocab_size_multiple
|
110 |
+
self.use_flash_attn = use_flash_attn
|
modeling_bert.py
CHANGED
@@ -59,6 +59,7 @@ logger = logging.getLogger(__name__)
|
|
59 |
|
60 |
|
61 |
def create_mixer_cls(config, cross_attn=False, return_residual=False):
|
|
|
62 |
fused_bias_fc = getattr(config, "fused_bias_fc", False)
|
63 |
window_size = getattr(config, "window_size", (-1, -1))
|
64 |
mixer_cls = partial(
|
@@ -68,7 +69,7 @@ def create_mixer_cls(config, cross_attn=False, return_residual=False):
|
|
68 |
dropout=config.attention_probs_dropout_prob,
|
69 |
causal=False,
|
70 |
fused_bias_fc=fused_bias_fc,
|
71 |
-
use_flash_attn=
|
72 |
return_residual=return_residual,
|
73 |
use_alibi=True,
|
74 |
window_size=window_size,
|
@@ -151,6 +152,7 @@ def _init_weights(module, initializer_range=0.02):
|
|
151 |
class BertEncoder(nn.Module):
|
152 |
def __init__(self, config: JinaBertConfig):
|
153 |
super().__init__()
|
|
|
154 |
self.layers = nn.ModuleList(
|
155 |
[create_block(config, layer_idx=i) for i in range(config.num_hidden_layers)]
|
156 |
)
|
@@ -171,7 +173,7 @@ class BertEncoder(nn.Module):
|
|
171 |
This means that we only compute the last layer output for these tokens.
|
172 |
subset_mask: (batch, seqlen), dtype=torch.bool
|
173 |
"""
|
174 |
-
if key_padding_mask is None:
|
175 |
mixer_kwargs = (
|
176 |
{"key_padding_mask": key_padding_mask} if key_padding_mask is not None else None
|
177 |
)
|
|
|
59 |
|
60 |
|
61 |
def create_mixer_cls(config, cross_attn=False, return_residual=False):
|
62 |
+
use_flash_attn = getattr(config, "use_flash_attn", False)
|
63 |
fused_bias_fc = getattr(config, "fused_bias_fc", False)
|
64 |
window_size = getattr(config, "window_size", (-1, -1))
|
65 |
mixer_cls = partial(
|
|
|
69 |
dropout=config.attention_probs_dropout_prob,
|
70 |
causal=False,
|
71 |
fused_bias_fc=fused_bias_fc,
|
72 |
+
use_flash_attn=use_flash_attn,
|
73 |
return_residual=return_residual,
|
74 |
use_alibi=True,
|
75 |
window_size=window_size,
|
|
|
152 |
class BertEncoder(nn.Module):
|
153 |
def __init__(self, config: JinaBertConfig):
|
154 |
super().__init__()
|
155 |
+
self.use_flash_attn = getattr(config, "use_flash_attn", False)
|
156 |
self.layers = nn.ModuleList(
|
157 |
[create_block(config, layer_idx=i) for i in range(config.num_hidden_layers)]
|
158 |
)
|
|
|
173 |
This means that we only compute the last layer output for these tokens.
|
174 |
subset_mask: (batch, seqlen), dtype=torch.bool
|
175 |
"""
|
176 |
+
if key_padding_mask is None or not self.use_flash_attn:
|
177 |
mixer_kwargs = (
|
178 |
{"key_padding_mask": key_padding_mask} if key_padding_mask is not None else None
|
179 |
)
|