feat: added top-level docstring, made it compatible with AutoModel
modeling_bert.py  +10 -5
@@ -1,3 +1,10 @@
+""" Implementation of BERT, using ALiBi and Flash Attention
+
+The implementation was adopted from
+https://github.com/Dao-AILab/flash-attention/blob/43950dda456e095969d842fca7a73c5bfe3cecd0/flash_attn/models/bert.py
+and made modifications to use ALiBi.
+"""
+
 # Copyright (c) 2022, Tri Dao.
 # This BERT implementation is based on our MLPerf 2.0 and MLPerf 2.1 BERT implementation.
 # https://github.com/mlcommons/training_results_v2.0/blob/main/HazyResearch/benchmarks/bert/implementations/pytorch/modeling.py

@@ -297,12 +304,10 @@ class BertPreTrainedModel(nn.Module):
 
     def __init__(self, config, *inputs, **kwargs):
         super().__init__()
-        if not isinstance(config, BertConfig):
+        if not config.__class__.__name__ == 'JinaBertConfig':
             raise ValueError(
-                "Parameter config in `{}(config)` should be an instance of class `BertConfig`. "
-                "To create a model from a Google pretrained model use "
-                "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
-                    self.__class__.__name__, self.__class__.__name__
+                "Parameter config in `{}(config)` should be an instance of class `JinaBertConfig`.".format(
+                    self.__class__.__name__,
                 )
             )
         self.config = config
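The switch from an isinstance() check to a class-name comparison is presumably what makes the module loadable through AutoModel: with trust_remote_code=True, transformers imports the configuration class dynamically from the repository, so it is a different class object than one imported locally, and an isinstance() check can fail even for the "same" config. A minimal usage sketch under that assumption follows; the repository id is hypothetical and not part of this diff.

# Minimal sketch, assuming a hypothetical Hub repo that ships this modeling_bert.py
# and a JinaBertConfig as custom code.
from transformers import AutoModel

model = AutoModel.from_pretrained(
    "some-org/bert-with-remote-code",  # hypothetical repo id, not from the diff
    trust_remote_code=True,            # lets transformers load modeling_bert.py from the repo
)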