feat: added get_input_embeddings method to BertForPreTraining
Browse files- modeling_bert.py +3 -0
modeling_bert.py
CHANGED
@@ -459,6 +459,9 @@ class BertForPreTraining(BertPreTrainedModel):
|
|
459 |
def tie_weights(self):
|
460 |
self.cls.predictions.decoder.weight = self.bert.embeddings.word_embeddings.weight
|
461 |
|
|
|
|
|
|
|
462 |
def forward(
|
463 |
self,
|
464 |
input_ids,
|
|
|
459 |
def tie_weights(self):
|
460 |
self.cls.predictions.decoder.weight = self.bert.embeddings.word_embeddings.weight
|
461 |
|
462 |
+
def get_input_embeddings(self):
|
463 |
+
return self.embeddings.word_embeddings
|
464 |
+
|
465 |
def forward(
|
466 |
self,
|
467 |
input_ids,
|