Markus28 commited on
Commit
bb281f0
1 Parent(s): 18eed80

feat: added get_input_embeddings method to BertForPreTraining

Browse files
Files changed (1) hide show
  1. modeling_bert.py +3 -0
modeling_bert.py CHANGED
@@ -459,6 +459,9 @@ class BertForPreTraining(BertPreTrainedModel):
459
  def tie_weights(self):
460
  self.cls.predictions.decoder.weight = self.bert.embeddings.word_embeddings.weight
461
 
 
 
 
462
  def forward(
463
  self,
464
  input_ids,
 
459
  def tie_weights(self):
460
  self.cls.predictions.decoder.weight = self.bert.embeddings.word_embeddings.weight
461
 
462
+ def get_input_embeddings(self):
463
+ return self.embeddings.word_embeddings
464
+
465
  def forward(
466
  self,
467
  input_ids,