Create modeling_xlm_roberta_for_glue.py

#4
Files changed (1)
  1. modeling_xlm_roberta_for_glue.py +109 -0
modeling_xlm_roberta_for_glue.py ADDED
@@ -0,0 +1,109 @@
+from typing import Optional, Union, Tuple
+
+import torch
+from torch import nn
+from torch.nn import CrossEntropyLoss, MSELoss, BCEWithLogitsLoss
+from transformers.modeling_outputs import SequenceClassifierOutput, QuestionAnsweringModelOutput, TokenClassifierOutput
+
+from .modeling_bert import XLMRobertaPreTrainedModel, XLMRobertaModel
+from .configuration_xlm_roberta import XLMRobertaFlashConfig
+
+
+class XLMRobertaForSequenceClassification(XLMRobertaPreTrainedModel):
+    def __init__(self, config: XLMRobertaFlashConfig):
+        super().__init__(config)
+        self.num_labels = config.num_labels
+        self.config = config
+
+        self.roberta = XLMRobertaModel(config)
+        classifier_dropout = (
+            config.classifier_dropout
+            if config.classifier_dropout is not None
+            else config.hidden_dropout_prob
+        )
+        self.dropout = nn.Dropout(classifier_dropout)
+        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+
+    def forward(
+        self,
+        input_ids: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        token_type_ids: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        labels: Optional[torch.Tensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
+        r"""
+        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
+            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), if
+            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+        """
+        return_dict = (
+            return_dict if return_dict is not None else self.config.use_return_dict
+        )
+
+        assert head_mask is None
+        assert inputs_embeds is None
+        assert output_attentions is None
+        assert output_hidden_states is None
+        assert return_dict
+        outputs = self.roberta(
+            input_ids,
+            attention_mask=attention_mask,
+            token_type_ids=token_type_ids,
+            position_ids=position_ids,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+        pooled_output = outputs[1]
+
+        pooled_output = self.dropout(pooled_output)
+        logits = self.classifier(pooled_output)
+
+        loss = None
+        if labels is not None:
+            if self.config.problem_type is None:
+                if self.num_labels == 1:
+                    self.config.problem_type = "regression"
+                elif self.num_labels > 1 and (
+                    labels.dtype == torch.long or labels.dtype == torch.int
+                ):
+                    self.config.problem_type = "single_label_classification"
+                else:
+                    self.config.problem_type = "multi_label_classification"
+
+            if self.config.problem_type == "regression":
+                loss_fct = MSELoss()
+                if self.num_labels == 1:
+                    loss = loss_fct(logits.squeeze(), labels.squeeze())
+                else:
+                    loss = loss_fct(logits, labels)
+            elif self.config.problem_type == "single_label_classification":
+                loss_fct = CrossEntropyLoss()
+                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+            elif self.config.problem_type == "multi_label_classification":
+                loss_fct = BCEWithLogitsLoss()
+                loss = loss_fct(logits, labels)
+        if not return_dict:
+            output = (logits,) + outputs[2:]
+            return ((loss,) + output) if loss is not None else output
+
+        return SequenceClassifierOutput(
+            loss=loss,
+            logits=logits,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
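
For context, a minimal usage sketch of the new classification head. This assumes the file ships alongside the repo's configuration_xlm_roberta.py and modeling_bert.py and that the checkpoint's auto_map wires AutoModelForSequenceClassification to this class; the checkpoint name and num_labels=2 below are placeholders for illustration, not part of this change.

import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# Placeholder checkpoint name (hypothetical); any repo that maps this class
# via auto_map would be loaded the same way.
model_name = "your-org/xlm-roberta-flash-checkpoint"

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(
    model_name,
    trust_remote_code=True,  # needed so the hub-side modeling code above is used
    num_labels=2,            # e.g. a binary GLUE task such as SST-2
)

inputs = tokenizer("A quick smoke test.", return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs, labels=torch.tensor([1]))

print(outputs.loss, outputs.logits.shape)  # scalar loss, logits of shape (1, num_labels)

Note that the forward pass asserts return_dict, head_mask is None, inputs_embeds is None, and that attention/hidden-state outputs are not requested, so callers should stick to the default return-dict path as in the sketch above.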