Tom Aarsen
commited on
Commit
•
0d894e3
1
Parent(s):
66ea5d6
Allow loading via AutoModelForSequenceClassification
Browse files- bert_layers.py +9 -0
bert_layers.py
CHANGED
@@ -29,6 +29,7 @@ from transformers.modeling_outputs import (MaskedLMOutput,
|
|
29 |
from transformers.models.bert.modeling_bert import BertPreTrainedModel
|
30 |
|
31 |
from .blockdiag_linear import BlockdiagLinear
|
|
|
32 |
from .monarch_mixer_sequence_mixer import MonarchMixerSequenceMixing
|
33 |
|
34 |
logger = logging.getLogger(__name__)
|
@@ -475,6 +476,8 @@ class BertModel(BertPreTrainedModel):
|
|
475 |
```
|
476 |
"""
|
477 |
|
|
|
|
|
478 |
def __init__(self, config, add_pooling_layer=True):
|
479 |
super(BertModel, self).__init__(config)
|
480 |
self.embeddings = BertEmbeddings(config)
|
@@ -602,6 +605,8 @@ class BertOnlyNSPHead(nn.Module):
|
|
602 |
#######################
|
603 |
class BertForMaskedLM(BertPreTrainedModel):
|
604 |
|
|
|
|
|
605 |
def __init__(self, config):
|
606 |
super().__init__(config)
|
607 |
|
@@ -748,6 +753,8 @@ class BertForSequenceClassification(BertPreTrainedModel):
|
|
748 |
e.g., GLUE tasks.
|
749 |
"""
|
750 |
|
|
|
|
|
751 |
def __init__(self, config):
|
752 |
super().__init__(config)
|
753 |
self.num_labels = config.num_labels
|
@@ -873,6 +880,8 @@ class BertForSequenceClassification(BertPreTrainedModel):
|
|
873 |
|
874 |
class BertForTextEncoding(BertPreTrainedModel):
|
875 |
|
|
|
|
|
876 |
def __init__(self, config):
|
877 |
super().__init__(config)
|
878 |
|
|
|
29 |
from transformers.models.bert.modeling_bert import BertPreTrainedModel
|
30 |
|
31 |
from .blockdiag_linear import BlockdiagLinear
|
32 |
+
from .configuration_bert import BertConfig
|
33 |
from .monarch_mixer_sequence_mixer import MonarchMixerSequenceMixing
|
34 |
|
35 |
logger = logging.getLogger(__name__)
|
|
|
476 |
```
|
477 |
"""
|
478 |
|
479 |
+
config_class = BertConfig
|
480 |
+
|
481 |
def __init__(self, config, add_pooling_layer=True):
|
482 |
super(BertModel, self).__init__(config)
|
483 |
self.embeddings = BertEmbeddings(config)
|
|
|
605 |
#######################
|
606 |
class BertForMaskedLM(BertPreTrainedModel):
|
607 |
|
608 |
+
config_class = BertConfig
|
609 |
+
|
610 |
def __init__(self, config):
|
611 |
super().__init__(config)
|
612 |
|
|
|
753 |
e.g., GLUE tasks.
|
754 |
"""
|
755 |
|
756 |
+
config_class = BertConfig
|
757 |
+
|
758 |
def __init__(self, config):
|
759 |
super().__init__(config)
|
760 |
self.num_labels = config.num_labels
|
|
|
880 |
|
881 |
class BertForTextEncoding(BertPreTrainedModel):
|
882 |
|
883 |
+
config_class = BertConfig
|
884 |
+
|
885 |
def __init__(self, config):
|
886 |
super().__init__(config)
|
887 |
|