Update AbLang_bert_model.py
AbLang_bert_model.py CHANGED (+1 -79)
@@ -31,82 +31,4 @@ class BertEmbeddingsV2(BertEmbeddings):
 class BertModelV2(BertModel):
     def __init__(self, config):
         super().__init__(config)
-        self.embeddings = BertEmbeddingsV2(config)
-
-
-class BertForMaskedLMV2(BertForMaskedLM):
-    def __init__(self, config):
-        super().__init__(config)
-
-    def forward(
-        self,
-        input_ids: Optional[torch.Tensor] = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        token_type_ids: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.Tensor] = None,
-        head_mask: Optional[torch.Tensor] = None,
-        inputs_embeds: Optional[torch.Tensor] = None,
-        encoder_hidden_states: Optional[torch.Tensor] = None,
-        encoder_attention_mask: Optional[torch.Tensor] = None,
-        labels: Optional[torch.Tensor] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
-        return_dict: Optional[bool] = None,
-    ) -> Union[Tuple[torch.Tensor], MaskedLMOutput]:
-        r"""
-        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
-            Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
-            config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
-            loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
-        """
-
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
-        outputs = self.bert(
-            input_ids,
-            attention_mask=attention_mask,
-            token_type_ids=token_type_ids,
-            position_ids=position_ids,
-            head_mask=head_mask,
-            inputs_embeds=inputs_embeds,
-            encoder_hidden_states=encoder_hidden_states,
-            encoder_attention_mask=encoder_attention_mask,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
-        )
-
-        sequence_output = outputs[0]
-        prediction_scores = sequence_output[:, :, 0:24]
-
-        masked_lm_loss = None
-        if labels is not None:
-            loss_fct = torch.nn.CrossEntropyLoss()  # -100 index = padding token
-            masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
-
-        if not return_dict:
-            output = (prediction_scores,) + outputs[2:]
-            return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
-
-        return MaskedLMOutput(
-            loss=masked_lm_loss,
-            logits=prediction_scores,
-            hidden_states=outputs.hidden_states,
-            attentions=outputs.attentions,
-        )
-
-    def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_kwargs):
-        input_shape = input_ids.shape
-        effective_batch_size = input_shape[0]
-
-        # add a dummy token
-        if self.config.pad_token_id is None:
-            raise ValueError("The PAD token should be defined for generation")
-
-        attention_mask = torch.cat([attention_mask, attention_mask.new_zeros((attention_mask.shape[0], 1))], dim=-1)
-        dummy_token = torch.full(
-            (effective_batch_size, 1), self.config.pad_token_id, dtype=torch.long, device=input_ids.device
-        )
-        input_ids = torch.cat([input_ids, dummy_token], dim=1)
-
-        return {"input_ids": input_ids, "attention_mask": attention_mask}
+        self.embeddings = BertEmbeddingsV2(config)
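For illustration only (not part of this commit): a minimal usage sketch of the retained BertModelV2, which differs from the stock transformers BertModel only in swapping in BertEmbeddingsV2. The import path and config values below are assumptions made for the example, not AbLang's actual settings.

```python
# Hypothetical usage sketch; module path and config values are assumptions.
from transformers import BertConfig

from AbLang_bert_model import BertModelV2  # assumed local module name

config = BertConfig(
    vocab_size=24,           # assumption: small amino-acid-style vocabulary
    hidden_size=768,
    num_hidden_layers=12,
    num_attention_heads=12,
)

model = BertModelV2(config)
# The only override: the embedding module is BertEmbeddingsV2.
print(type(model.embeddings).__name__)  # -> "BertEmbeddingsV2"
```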