koukandre committed on
Commit 4f2b80b
1 Parent(s): f0925f9

Update modeling_xlm_roberta_for_glue.py

Files changed (1):
  1. modeling_xlm_roberta_for_glue.py +9 -4
modeling_xlm_roberta_for_glue.py CHANGED
@@ -6,16 +6,16 @@ from torch.nn import CrossEntropyLoss, MSELoss, BCEWithLogitsLoss
 from transformers.modeling_outputs import SequenceClassifierOutput, QuestionAnsweringModelOutput, TokenClassifierOutput
 
 from .modeling_bert import XLMRobertaPreTrainedModel, XLMRobertaModel
-from .configuration_bert import JinaBertConfig
+from .configuration_xlm_roberta import XLMRobertaFlashConfig
 
 
 class XLMRobertaForSequenceClassification(XLMRobertaPreTrainedModel):
-    def __init__(self, config: JinaBertConfig):
+    def __init__(self, config: XLMRobertaFlashConfig):
         super().__init__(config)
         self.num_labels = config.num_labels
         self.config = config
 
-        self.bert = XLMRobertaModel(config)
+        self.roberta = XLMRobertaModel(config)
         classifier_dropout = (
             config.classifier_dropout
             if config.classifier_dropout is not None
@@ -56,11 +56,16 @@ class XLMRobertaForSequenceClassification(XLMRobertaPreTrainedModel):
         assert output_attentions is None
         assert output_hidden_states is None
         assert return_dict
-        outputs = self.bert(
+        outputs = self.roberta(
             input_ids,
             attention_mask=attention_mask,
             token_type_ids=token_type_ids,
             position_ids=position_ids,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
         )
 
         pooled_output = outputs[1]
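
For context, a minimal usage sketch of the classification head this file defines. It assumes the file ships as remote code in a Hub repository whose config maps AutoModelForSequenceClassification to XLMRobertaForSequenceClassification; the repository id and num_labels below are placeholders, not taken from this commit.

# Hypothetical usage sketch; repo id and label count are assumptions.
from transformers import AutoTokenizer, AutoModelForSequenceClassification

repo_id = "jinaai/xlm-roberta-flash-implementation"  # placeholder repo id

tokenizer = AutoTokenizer.from_pretrained(repo_id)
model = AutoModelForSequenceClassification.from_pretrained(
    repo_id,
    trust_remote_code=True,  # needed so the custom XLMRobertaForSequenceClassification is used
    num_labels=2,            # placeholder GLUE-style label count
)

# Tokenize a sentence pair and run the classification head.
inputs = tokenizer("First sentence.", "Second sentence.", return_tensors="pt")
outputs = model(**inputs)
print(outputs.logits.shape)  # (1, num_labels)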