feat: updated modeling_bert.py to allow MLM-only training
Browse files- modeling_bert.py +19 -15
modeling_bert.py
CHANGED
@@ -494,24 +494,28 @@ class BertForPreTraining(BertPreTrainedModel):
        )
        prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)

-       (13 deleted lines, old lines 497–509 — their content was lost in extraction;
-        per the commit message they held the previous, unconditional MLM/NSP loss
-        computation replaced by the conditional version below)
+       if (
+           self.dense_seq_output and labels is not None
+       ):  # prediction_scores are already flattened
+           masked_lm_loss = self.mlm_loss(
+               prediction_scores, labels.flatten()[masked_token_idx]
+           ).float()
+       elif labels is not None:
+           masked_lm_loss = self.mlm_loss(
+               rearrange(prediction_scores, "... v -> (...) v"),
+               rearrange(labels, "... -> (...)"),
+           ).float()
+       else:
+           masked_lm_loss = 0
+       if next_sentence_label is not None:
            next_sentence_loss = self.nsp_loss(
                rearrange(seq_relationship_score, "... t -> (...) t"),
                rearrange(next_sentence_label, "... -> (...)"),
-           )
+           ).float()
+       else:
+           next_sentence_loss = 0
+
+       total_loss = masked_lm_loss + next_sentence_loss

        return BertForPreTrainingOutput(
            loss=total_loss,