feat: added separate BertForMaskedLM class
modeling_bert.py  CHANGED  +80 -0
@@ -689,4 +689,84 @@ class BertForPreTraining(BertPreTrainedModel):
             loss=total_loss,
             prediction_logits=prediction_scores,
             seq_relationship_logits=seq_relationship_score,
         )
+
+
+class BertForMaskedLM(BertPreTrainedModel):
+    def __init__(self, config: JinaBertConfig):
+        super().__init__(config)
+        # If dense_seq_output, we only need to pass the hidden states for the masked out tokens
+        # (around 15%) to the classifier heads.
+        self.dense_seq_output = getattr(config, "dense_seq_output", False)
+        # If last_layer_subset, we only need to compute the last layer for a subset of tokens
+        # (e.g., the tokens we need to compute the masked LM loss and the next-sentence prediction).
+        self.last_layer_subset = getattr(config, "last_layer_subset", False)
+        if self.last_layer_subset:
+            assert self.dense_seq_output, "last_layer_subset requires dense_seq_output"
+        use_xentropy = getattr(config, "use_xentropy", False)
+        if use_xentropy and CrossEntropyLoss is None:
+            raise ImportError("xentropy_cuda is not installed")
+        loss_cls = (
+            nn.CrossEntropyLoss
+            if not use_xentropy
+            else partial(CrossEntropyLoss, inplace_backward=True)
+        )
+
+        self.bert = BertModel(config)
+        self.cls = BertPreTrainingHeads(config)
+        self.mlm_loss = loss_cls(ignore_index=0)
+
+        # Initialize weights and apply final processing
+        self.apply(partial(_init_weights, initializer_range=config.initializer_range))
+        self.tie_weights()
+
+    def tie_weights(self):
+        self.cls.predictions.decoder.weight = self.bert.embeddings.word_embeddings.weight
+
+    def get_input_embeddings(self):
+        return self.bert.embeddings.word_embeddings
+
+    def forward(
+        self,
+        input_ids,
+        position_ids=None,
+        token_type_ids=None,
+        attention_mask=None,
+        labels=None,
+    ):
+        masked_tokens_mask = labels > 0 if (self.last_layer_subset and labels is not None) else None
+        outputs = self.bert(
+            input_ids,
+            position_ids=position_ids,
+            token_type_ids=token_type_ids,
+            attention_mask=attention_mask.bool() if attention_mask is not None else None,
+            masked_tokens_mask=masked_tokens_mask,
+        )
+        sequence_output, pooled_output = outputs.last_hidden_state, outputs.pooler_output
+        if self.dense_seq_output and labels is not None:
+            masked_token_idx = torch.nonzero(labels.flatten() > 0, as_tuple=False).flatten()
+            if not self.last_layer_subset:
+                sequence_output = index_first_axis(
+                    rearrange(sequence_output, "b s d -> (b s) d"), masked_token_idx
+                )
+        prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)
+
+        if (
+            self.dense_seq_output and labels is not None
+        ):  # prediction_scores are already flattened
+            masked_lm_loss = self.mlm_loss(
+                prediction_scores, labels.flatten()[masked_token_idx]
+            ).float()
+        else:
+            assert labels is not None
+            masked_lm_loss = self.mlm_loss(
+                rearrange(prediction_scores, "... v -> (...) v"),
+                rearrange(labels, "... -> (...)"),
+            ).float()
+
+        return BertForPreTrainingOutput(
+            loss=masked_lm_loss,
+            prediction_logits=prediction_scores,
+            seq_relationship_logits=seq_relationship_score,
+        )
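
Usage note (not part of this commit): a minimal sketch of how the new class might be driven. The JinaBertConfig field names, module paths, and the [MASK] token id below are assumptions for illustration; the only conventions taken from the diff are the forward() signature, the dense_seq_output flag, and that label 0 marks non-masked positions (ignore_index=0).

    # Illustrative sketch only; config fields and import paths are assumed, not defined by this diff.
    import torch
    from configuration_bert import JinaBertConfig   # assumed module path
    from modeling_bert import BertForMaskedLM

    config = JinaBertConfig(
        vocab_size=30528,
        hidden_size=768,
        num_hidden_layers=12,
        num_attention_heads=12,
        dense_seq_output=True,  # route only the ~15% masked positions through the MLM head
    )
    model = BertForMaskedLM(config)

    input_ids = torch.randint(1, config.vocab_size, (2, 128))
    attention_mask = torch.ones_like(input_ids)
    labels = torch.zeros_like(input_ids)       # 0 = not masked, ignored by the loss (ignore_index=0)
    labels[:, 10] = input_ids[:, 10]           # pretend position 10 of each sequence was masked
    input_ids[:, 10] = 103                     # hypothetical [MASK] token id

    outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
    print(outputs.loss, outputs.prediction_logits.shape)  # with dense_seq_output, logits cover only masked positions

With dense_seq_output enabled, the gather via index_first_axis is equivalent to hidden_states.view(-1, hidden_size)[masked_token_idx], so the MLM head and loss only run over the masked positions rather than the full sequence.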