fix BertForMaskedLM
Browse files — modeling_bert.py (+8 −8)
modeling_bert.py
CHANGED
@@ -752,18 +752,18 @@ class BertForMaskedLM(BertPreTrainedModel):
Before (lines 752–769; "-" marks lines removed by this commit — the original
content of the removed lines was not captured in this extraction):

    752          prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)
    753
    754          if (
    755  -           <removed line — content not captured>
    756          ):  # prediction_scores are already flattened
    757              masked_lm_loss = self.mlm_loss(
    758                  prediction_scores, labels.flatten()[masked_token_idx]
    759              ).float()
    760  -           <removed line — content not captured>
    761  -           <removed line — content not captured>
    762  -           <removed line — content not captured>
    763  -           <removed line — content not captured>
    764  -           <removed line — content not captured>
    765  -           <removed line — content not captured>
    766  -           <removed line — content not captured>
    767
    768          return BertForPreTrainingOutput(
    769              loss=masked_lm_loss,
After (lines 752–769; "+" marks lines added by this commit):

    752          prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)
    753
    754          if (
    755  +           self.dense_seq_output and labels is not None
    756          ):  # prediction_scores are already flattened
    757              masked_lm_loss = self.mlm_loss(
    758                  prediction_scores, labels.flatten()[masked_token_idx]
    759              ).float()
    760  +       elif labels is not None:
    761  +           masked_lm_loss = self.mlm_loss(
    762  +               rearrange(prediction_scores, "... v -> (...) v"),
    763  +               rearrange(labels, "... -> (...)"),
    764  +           ).float()
    765  +       else:
    766  +           raise ValueError('MLM labels must not be None')
    767
    768          return BertForPreTrainingOutput(
    769              loss=masked_lm_loss,