feat: reverted monkey patch
- configuration_bert.py  +0 -2
- modeling_bert.py  +5 -17
configuration_bert.py
CHANGED
@@ -14,8 +14,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """ BERT model configuration"""
-from collections import OrderedDict
-from typing import Mapping
 
 from transformers import PretrainedConfig
 
modeling_bert.py
CHANGED
@@ -28,16 +28,13 @@ from transformers.models.bert.modeling_bert import (
     BaseModelOutputWithPoolingAndCrossAttentions,
     BertForPreTrainingOutput,
 )
-from .patched_padding_bert import index_first_axis as index_first_axis_monkey_patch
-import flash_attn.bert_padding
-flash_attn.bert_padding.index_first_axis = index_first_axis_monkey_patch
-"""
 from flash_attn.bert_padding import (
+    index_first_axis,
     index_first_axis_residual,
     pad_input,
     unpad_input,
 )
-"""
+
 from flash_attn.modules.block import Block
 from flash_attn.modules.embedding import BertEmbeddings
 from flash_attn.modules.mha import MHA
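The removed lines above are the monkey patch itself: the file imported its own index_first_axis from patched_padding_bert and assigned it onto the flash_attn.bert_padding module so later lookups would resolve to the patched version, while the stock import block appears to have been fenced off. The revert restores the plain flash_attn import, with index_first_axis added back to the import list. A minimal sketch of that patching pattern follows; my_patched_index_first_axis is a hypothetical stand-in, not the repo's patched implementation.

import torch
import flash_attn.bert_padding


def my_patched_index_first_axis(tensor: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
    # Plain gather along the first axis, standing in for the patched kernel.
    return tensor[indices]


# The monkey patch: swap the attribute on the already-imported module.
flash_attn.bert_padding.index_first_axis = my_patched_index_first_axis
# Callers that resolve flash_attn.bert_padding.index_first_axis at call time now hit
# the patch; callers that had already bound the name via
# `from flash_attn.bert_padding import index_first_axis` keep the original.
# The revert drops the assignment and imports the upstream symbol directly again.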
@@ -176,14 +173,14 @@ class BertEncoder(nn.Module):
                 hidden_states = hidden_states[subset_mask]
         else:
             batch, seqlen = hidden_states.shape[:2]
-            hidden_states, indices, cu_seqlens, max_seqlen_in_batch =
+            hidden_states, indices, cu_seqlens, max_seqlen_in_batch = unpad_input(
                 hidden_states, key_padding_mask
             )
             mixer_kwargs = {"cu_seqlens": cu_seqlens, "max_seqlen": max_seqlen_in_batch}
             if subset_mask is None:
                 for layer in self.layers:
                     hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
-                hidden_states =
+                hidden_states = pad_input(hidden_states, indices, batch, seqlen)
             else:
                 for layer in self.layers[:-1]:
                     hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
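The restored calls in this hunk are flash_attn's padding round trip: unpad_input packs the valid tokens of a (batch, seqlen, hidden) tensor into a (total_tokens, hidden) tensor plus cu_seqlens offsets, the encoder layers run on the packed tensor, and pad_input scatters the result back into the padded layout. A rough plain-PyTorch sketch of the shapes involved; sketch_unpad and sketch_pad are hypothetical helpers, not the flash_attn implementations, which go through a custom autograd Function.

import torch
import torch.nn.functional as F


def sketch_unpad(hidden_states: torch.Tensor, key_padding_mask: torch.Tensor):
    # hidden_states: (batch, seqlen, hidden); key_padding_mask: (batch, seqlen) bool
    seqlens = key_padding_mask.sum(dim=-1, dtype=torch.int32)            # tokens per sequence
    indices = torch.nonzero(key_padding_mask.flatten(), as_tuple=False).flatten()
    cu_seqlens = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0))
    packed = hidden_states.reshape(-1, hidden_states.shape[-1])[indices]  # (total_tokens, hidden)
    return packed, indices, cu_seqlens, int(seqlens.max())


def sketch_pad(packed: torch.Tensor, indices: torch.Tensor, batch: int, seqlen: int):
    # Scatter the packed tokens back into a zero-padded (batch, seqlen, hidden) tensor.
    out = packed.new_zeros(batch * seqlen, packed.shape[-1])
    out[indices] = packed
    return out.reshape(batch, seqlen, -1)


x = torch.randn(2, 4, 8)
mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]], dtype=torch.bool)
packed, idx, cu, max_len = sketch_unpad(x, mask)   # packed: (5, 8), cu: [0, 3, 5]
restored = sketch_pad(packed, idx, 2, 4)           # padded positions come back as zeros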
@@ -201,7 +198,7 @@
                     subset_cu_seqlens = F.pad(
                         torch.cumsum(subset_seqlens, dim=0, dtype=torch.torch.int32), (1, 0)
                     )
-                hidden_states_subset, hidden_states =
+                hidden_states_subset, hidden_states = index_first_axis_residual(
                     hidden_states, subset_idx
                 )
                 # It's ok to set max_seqlen_q to be much larger
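Two small operations matter in the subset branch: the F.pad(torch.cumsum(...)) idiom in the context lines builds the cumulative-length offsets for the selected tokens, and the restored index_first_axis_residual call gathers just those rows while also handing back the full packed tensor as a second return value (visible in the unpacking on the left-hand side). A small worked example with made-up numbers:

import torch
import torch.nn.functional as F

subset_seqlens = torch.tensor([2, 1, 3], dtype=torch.int32)   # selected tokens per sequence
subset_cu_seqlens = F.pad(torch.cumsum(subset_seqlens, dim=0, dtype=torch.int32), (1, 0))
# -> tensor([0, 2, 3, 6], dtype=torch.int32): offset of each sequence in the packed subset

hidden_states = torch.randn(8, 16)                             # packed (total_tokens, hidden)
subset_idx = torch.tensor([0, 3, 4, 5, 6, 7])                  # positions kept, e.g. masked tokens
hidden_states_subset = hidden_states[subset_idx]               # what index_first_axis_residual gathers;
# the real call also returns the full hidden_states unchanged as its second value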
@@ -425,15 +422,6 @@ class BertModel(BertPreTrainedModel):
             pooler_output=pooled_output,
         )
 
-    def to(self, *args, **kwargs):
-        print(f'In BERT, calling to({args, kwargs})')
-        result = super().to(*args, **kwargs)
-        if (len(args) > 0 and isinstance(args[0], torch.dtype)) or "dtype" in kwargs:
-            for layer in result.encoder.layers:
-                layer.mixer.inner_cross_attn.alibi_slopes = layer.mixer.inner_cross_attn.alibi_slopes.to(torch.float32)
-                layer.mixer.inner_attn.alibi_slopes = layer.mixer.inner_attn.alibi_slopes.to(torch.float32)
-        return result
-
 
 class BertForPreTraining(BertPreTrainedModel):
     def __init__(self, config: JinaBertConfig):
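The deleted to() override (together with its debug print) intercepted dtype casts and forced the ALiBi slope buffers of every layer back to float32, presumably because the flash-attention kernels expect the slopes in fp32 even when the rest of the model runs in fp16 or bf16. A self-contained sketch of that "keep one buffer in fp32 across casts" pattern on a toy module; TinyModel and AlibiAttention are made up for illustration and are not the repository's classes.

import torch
import torch.nn as nn


class AlibiAttention(nn.Module):
    def __init__(self, num_heads: int = 8):
        super().__init__()
        # One slope per head; registered as a buffer, so .to(dtype) would normally cast it.
        self.register_buffer("alibi_slopes", torch.rand(num_heads, dtype=torch.float32))


class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.attn = AlibiAttention()

    def to(self, *args, **kwargs):
        result = super().to(*args, **kwargs)
        # If the call included a dtype (e.g. model.to(torch.bfloat16)),
        # undo the cast on the slope buffer only.
        if (len(args) > 0 and isinstance(args[0], torch.dtype)) or "dtype" in kwargs:
            result.attn.alibi_slopes = result.attn.alibi_slopes.to(torch.float32)
        return result


m = TinyModel().to(torch.bfloat16)
assert m.attn.alibi_slopes.dtype == torch.float32   # slopes stay in fp32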