omarmomen commited on
Commit
613d4d7
1 Parent(s): 2e0b1b9
Files changed (3) hide show
  1. config.json +39 -0
  2. modeling_structroberta.py +2146 -0
  3. pytorch_model.bin +3 -0
config.json ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "StructRoberta"
4
+ ],
5
+ "attention_probs_dropout_prob": 0.1,
6
+ "auto_map": {
7
+ "AutoConfig": "modeling_structroberta.StructRobertaConfig",
8
+ "AutoModelForMaskedLM": "modeling_structroberta.StructRoberta"
9
+ },
10
+ "bos_token_id": 0,
11
+ "classifier_dropout": null,
12
+ "conv_size": 9,
13
+ "eos_token_id": 2,
14
+ "hidden_act": "gelu",
15
+ "hidden_dropout_prob": 0.1,
16
+ "hidden_size": 768,
17
+ "initializer_range": 0.02,
18
+ "intermediate_size": 3072,
19
+ "layer_norm_eps": 1e-05,
20
+ "max_position_embeddings": 514,
21
+ "model_type": "roberta",
22
+ "n_cntxt_layers": 4,
23
+ "n_cntxt_layers_2": 0,
24
+ "n_parser_layers": 6,
25
+ "num_attention_heads": 12,
26
+ "num_hidden_layers": 8,
27
+ "pad_token_id": 1,
28
+ "position_embedding_type": "absolute",
29
+ "relations": [
30
+ "head",
31
+ "child"
32
+ ],
33
+ "torch_dtype": "float32",
34
+ "transformers_version": "4.18.0",
35
+ "type_vocab_size": 1,
36
+ "use_cache": true,
37
+ "vocab_size": 32000,
38
+ "weight_act": "softmax"
39
+ }
modeling_structroberta.py ADDED
@@ -0,0 +1,2146 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import math
3
+ from typing import List, Optional, Tuple, Union
4
+ import torch
5
+ import torch.utils.checkpoint
6
+ from packaging import version
7
+ from torch import nn
8
+ import torch.nn.functional as F
9
+ from torch.nn import CrossEntropyLoss, MSELoss, BCEWithLogitsLoss
10
+ from transformers.activations import ACT2FN, gelu
11
+ from transformers.modeling_outputs import (
12
+ BaseModelOutputWithPastAndCrossAttentions,
13
+ BaseModelOutputWithPoolingAndCrossAttentions,
14
+ MaskedLMOutput,
15
+ SequenceClassifierOutput
16
+ )
17
+ from transformers.modeling_utils import (
18
+ PreTrainedModel,
19
+ apply_chunking_to_forward,
20
+ find_pruneable_heads_and_indices,
21
+ prune_linear_layer,
22
+ )
23
+ from transformers.utils import logging
24
+ from transformers import RobertaConfig
25
+
26
+
27
+ logger = logging.get_logger(__name__)
28
+ ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST = [
29
+ "roberta-base",
30
+ "roberta-large",
31
+ "roberta-large-mnli",
32
+ "distilroberta-base",
33
+ "roberta-base-openai-detector",
34
+ "roberta-large-openai-detector",
35
+ # See all RoBERTa models at https://huggingface.co/models?filter=roberta
36
+ ]
37
+
38
+
39
+ class StructRobertaConfig(RobertaConfig):
40
+ model_type = "roberta"
41
+
42
+ def __init__(
43
+ self,
44
+ n_parser_layers=4,
45
+ conv_size=9,
46
+ relations=("head", "child"),
47
+ weight_act="softmax",
48
+ n_cntxt_layers=3,
49
+ n_cntxt_layers_2=0,
50
+ **kwargs,):
51
+
52
+ super().__init__(**kwargs)
53
+ self.n_cntxt_layers = n_cntxt_layers
54
+ self.n_parser_layers = n_parser_layers
55
+ self.n_cntxt_layers_2 = n_cntxt_layers_2
56
+ self.conv_size = conv_size
57
+ self.relations = relations
58
+ self.weight_act = weight_act
59
+
60
+ class Conv1d(nn.Module):
61
+ """1D convolution layer."""
62
+
63
+ def __init__(self, hidden_size, kernel_size, dilation=1):
64
+ """Initialization.
65
+
66
+ Args:
67
+ hidden_size: dimension of input embeddings
68
+ kernel_size: convolution kernel size
69
+ dilation: the spacing between the kernel points
70
+ """
71
+ super(Conv1d, self).__init__()
72
+
73
+ if kernel_size % 2 == 0:
74
+ padding = (kernel_size // 2) * dilation
75
+ self.shift = True
76
+ else:
77
+ padding = ((kernel_size - 1) // 2) * dilation
78
+ self.shift = False
79
+ self.conv = nn.Conv1d(
80
+ hidden_size, hidden_size, kernel_size, padding=padding, dilation=dilation
81
+ )
82
+
83
+ def forward(self, x):
84
+ """Compute convolution.
85
+
86
+ Args:
87
+ x: input embeddings
88
+ Returns:
89
+ conv_output: convolution results
90
+ """
91
+
92
+ if self.shift:
93
+ return self.conv(x.transpose(1, 2)).transpose(1, 2)[:, 1:]
94
+ else:
95
+ return self.conv(x.transpose(1, 2)).transpose(1, 2)
96
+
97
+
98
+ class RobertaEmbeddings(nn.Module):
99
+ """
100
+ Same as BertEmbeddings with a tiny tweak for positional embeddings indexing.
101
+ """
102
+
103
+ # Copied from transformers.models.bert.modeling_bert.BertEmbeddings.__init__
104
+ def __init__(self, config):
105
+ super().__init__()
106
+ self.word_embeddings = nn.Embedding(
107
+ config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id
108
+ )
109
+ self.position_embeddings = nn.Embedding(
110
+ config.max_position_embeddings, config.hidden_size
111
+ )
112
+ self.token_type_embeddings = nn.Embedding(
113
+ config.type_vocab_size, config.hidden_size
114
+ )
115
+
116
+ # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
117
+ # any TensorFlow checkpoint file
118
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
119
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
120
+ # position_ids (1, len position emb) is contiguous in memory and exported when serialized
121
+ self.position_embedding_type = getattr(
122
+ config, "position_embedding_type", "absolute"
123
+ )
124
+ self.register_buffer(
125
+ "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1))
126
+ )
127
+ if version.parse(torch.__version__) > version.parse("1.6.0"):
128
+ self.register_buffer(
129
+ "token_type_ids",
130
+ torch.zeros(self.position_ids.size(), dtype=torch.long),
131
+ persistent=False,
132
+ )
133
+
134
+ # End copy
135
+ self.padding_idx = config.pad_token_id
136
+ self.position_embeddings = nn.Embedding(
137
+ config.max_position_embeddings,
138
+ config.hidden_size,
139
+ padding_idx=self.padding_idx,
140
+ )
141
+
142
+ def forward(
143
+ self,
144
+ input_ids=None,
145
+ token_type_ids=None,
146
+ position_ids=None,
147
+ inputs_embeds=None,
148
+ past_key_values_length=0,
149
+ ):
150
+ if position_ids is None:
151
+ if input_ids is not None:
152
+ # Create the position ids from the input token ids. Any padded tokens remain padded.
153
+ position_ids = create_position_ids_from_input_ids(
154
+ input_ids, self.padding_idx, past_key_values_length
155
+ )
156
+ else:
157
+ position_ids = self.create_position_ids_from_inputs_embeds(
158
+ inputs_embeds
159
+ )
160
+
161
+ if input_ids is not None:
162
+ input_shape = input_ids.size()
163
+ else:
164
+ input_shape = inputs_embeds.size()[:-1]
165
+
166
+ seq_length = input_shape[1]
167
+
168
+ # Setting the token_type_ids to the registered buffer in constructor where it is all zeros, which usually occurs
169
+ # when its auto-generated, registered buffer helps users when tracing the model without passing token_type_ids, solves
170
+ # issue #5664
171
+ if token_type_ids is None:
172
+ if hasattr(self, "token_type_ids"):
173
+ buffered_token_type_ids = self.token_type_ids[:, :seq_length]
174
+ buffered_token_type_ids_expanded = buffered_token_type_ids.expand(
175
+ input_shape[0], seq_length
176
+ )
177
+ token_type_ids = buffered_token_type_ids_expanded
178
+ else:
179
+ token_type_ids = torch.zeros(
180
+ input_shape, dtype=torch.long, device=self.position_ids.device
181
+ )
182
+
183
+ if inputs_embeds is None:
184
+ inputs_embeds = self.word_embeddings(input_ids)
185
+ token_type_embeddings = self.token_type_embeddings(token_type_ids)
186
+
187
+ embeddings = inputs_embeds + token_type_embeddings
188
+ if self.position_embedding_type == "absolute":
189
+ position_embeddings = self.position_embeddings(position_ids)
190
+ embeddings += position_embeddings
191
+ embeddings = self.LayerNorm(embeddings)
192
+ embeddings = self.dropout(embeddings)
193
+ return embeddings
194
+
195
+ def create_position_ids_from_inputs_embeds(self, inputs_embeds):
196
+ """
197
+ We are provided embeddings directly. We cannot infer which are padded so just generate sequential position ids.
198
+
199
+ Args:
200
+ inputs_embeds: torch.Tensor
201
+
202
+ Returns: torch.Tensor
203
+ """
204
+ input_shape = inputs_embeds.size()[:-1]
205
+ sequence_length = input_shape[1]
206
+
207
+ position_ids = torch.arange(
208
+ self.padding_idx + 1,
209
+ sequence_length + self.padding_idx + 1,
210
+ dtype=torch.long,
211
+ device=inputs_embeds.device,
212
+ )
213
+ return position_ids.unsqueeze(0).expand(input_shape)
214
+
215
+
216
+ # Copied from transformers.models.bert.modeling_bert.BertSelfAttention with Bert->Roberta
217
+ class RobertaSelfAttention(nn.Module):
218
+ def __init__(self, config, position_embedding_type=None):
219
+ super().__init__()
220
+ if config.hidden_size % config.num_attention_heads != 0 and not hasattr(
221
+ config, "embedding_size"
222
+ ):
223
+ raise ValueError(
224
+ f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
225
+ f"heads ({config.num_attention_heads})"
226
+ )
227
+
228
+ self.num_attention_heads = config.num_attention_heads
229
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
230
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
231
+
232
+ self.query = nn.Linear(config.hidden_size, self.all_head_size)
233
+ self.key = nn.Linear(config.hidden_size, self.all_head_size)
234
+ self.value = nn.Linear(config.hidden_size, self.all_head_size)
235
+
236
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
237
+ self.position_embedding_type = position_embedding_type or getattr(
238
+ config, "position_embedding_type", "absolute"
239
+ )
240
+ if (
241
+ self.position_embedding_type == "relative_key"
242
+ or self.position_embedding_type == "relative_key_query"
243
+ ):
244
+ self.max_position_embeddings = config.max_position_embeddings
245
+ self.distance_embedding = nn.Embedding(
246
+ 2 * config.max_position_embeddings - 1, self.attention_head_size
247
+ )
248
+
249
+ self.is_decoder = config.is_decoder
250
+
251
+ def transpose_for_scores(self, x):
252
+ new_x_shape = x.size()[:-1] + (
253
+ self.num_attention_heads,
254
+ self.attention_head_size,
255
+ )
256
+ x = x.view(new_x_shape)
257
+ return x.permute(0, 2, 1, 3)
258
+
259
+ def forward(
260
+ self,
261
+ hidden_states: torch.Tensor,
262
+ attention_mask: Optional[torch.FloatTensor] = None,
263
+ head_mask: Optional[torch.FloatTensor] = None,
264
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
265
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
266
+ past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
267
+ output_attentions: Optional[bool] = False,
268
+ parser_att_mask=None,
269
+ ) -> Tuple[torch.Tensor]:
270
+ mixed_query_layer = self.query(hidden_states)
271
+
272
+ # If this is instantiated as a cross-attention module, the keys
273
+ # and values come from an encoder; the attention mask needs to be
274
+ # such that the encoder's padding tokens are not attended to.
275
+ is_cross_attention = encoder_hidden_states is not None
276
+
277
+ if is_cross_attention and past_key_value is not None:
278
+ # reuse k,v, cross_attentions
279
+ key_layer = past_key_value[0]
280
+ value_layer = past_key_value[1]
281
+ attention_mask = encoder_attention_mask
282
+ elif is_cross_attention:
283
+ key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
284
+ value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
285
+ attention_mask = encoder_attention_mask
286
+ elif past_key_value is not None:
287
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
288
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
289
+ key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
290
+ value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
291
+ else:
292
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
293
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
294
+
295
+ query_layer = self.transpose_for_scores(mixed_query_layer)
296
+
297
+ if self.is_decoder:
298
+ # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
299
+ # Further calls to cross_attention layer can then reuse all cross-attention
300
+ # key/value_states (first "if" case)
301
+ # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
302
+ # all previous decoder key/value_states. Further calls to uni-directional self-attention
303
+ # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
304
+ # if encoder bi-directional self-attention `past_key_value` is always `None`
305
+ past_key_value = (key_layer, value_layer)
306
+
307
+ # Take the dot product between "query" and "key" to get the raw attention scores.
308
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
309
+
310
+ if (
311
+ self.position_embedding_type == "relative_key"
312
+ or self.position_embedding_type == "relative_key_query"
313
+ ):
314
+ seq_length = hidden_states.size()[1]
315
+ position_ids_l = torch.arange(
316
+ seq_length, dtype=torch.long, device=hidden_states.device
317
+ ).view(-1, 1)
318
+ position_ids_r = torch.arange(
319
+ seq_length, dtype=torch.long, device=hidden_states.device
320
+ ).view(1, -1)
321
+ distance = position_ids_l - position_ids_r
322
+ positional_embedding = self.distance_embedding(
323
+ distance + self.max_position_embeddings - 1
324
+ )
325
+ positional_embedding = positional_embedding.to(
326
+ dtype=query_layer.dtype
327
+ ) # fp16 compatibility
328
+
329
+ if self.position_embedding_type == "relative_key":
330
+ relative_position_scores = torch.einsum(
331
+ "bhld,lrd->bhlr", query_layer, positional_embedding
332
+ )
333
+ attention_scores = attention_scores + relative_position_scores
334
+ elif self.position_embedding_type == "relative_key_query":
335
+ relative_position_scores_query = torch.einsum(
336
+ "bhld,lrd->bhlr", query_layer, positional_embedding
337
+ )
338
+ relative_position_scores_key = torch.einsum(
339
+ "bhrd,lrd->bhlr", key_layer, positional_embedding
340
+ )
341
+ attention_scores = (
342
+ attention_scores
343
+ + relative_position_scores_query
344
+ + relative_position_scores_key
345
+ )
346
+
347
+ attention_scores = attention_scores / math.sqrt(self.attention_head_size)
348
+ if attention_mask is not None:
349
+ # Apply the attention mask is (precomputed for all layers in RobertaModel forward() function)
350
+ attention_scores = attention_scores + attention_mask
351
+
352
+ if parser_att_mask is None:
353
+ # Normalize the attention scores to probabilities.
354
+ attention_probs = nn.functional.softmax(attention_scores, dim=-1)
355
+ else:
356
+ attention_probs = torch.sigmoid(attention_scores) * parser_att_mask
357
+
358
+ # This is actually dropping out entire tokens to attend to, which might
359
+ # seem a bit unusual, but is taken from the original Transformer paper.
360
+ attention_probs = self.dropout(attention_probs)
361
+
362
+ # Mask heads if we want to
363
+ if head_mask is not None:
364
+ attention_probs = attention_probs * head_mask
365
+
366
+ context_layer = torch.matmul(attention_probs, value_layer)
367
+
368
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
369
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
370
+ context_layer = context_layer.view(new_context_layer_shape)
371
+
372
+ outputs = (
373
+ (context_layer, attention_probs) if output_attentions else (context_layer,)
374
+ )
375
+
376
+ if self.is_decoder:
377
+ outputs = outputs + (past_key_value,)
378
+ return outputs
379
+
380
+
381
+ # Copied from transformers.models.bert.modeling_bert.BertSelfOutput
382
+ class RobertaSelfOutput(nn.Module):
383
+ def __init__(self, config):
384
+ super().__init__()
385
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
386
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
387
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
388
+
389
+ def forward(
390
+ self, hidden_states: torch.Tensor, input_tensor: torch.Tensor
391
+ ) -> torch.Tensor:
392
+ hidden_states = self.dense(hidden_states)
393
+ hidden_states = self.dropout(hidden_states)
394
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
395
+ return hidden_states
396
+
397
+
398
+ # Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->Roberta
399
+ class RobertaAttention(nn.Module):
400
+ def __init__(self, config, position_embedding_type=None):
401
+ super().__init__()
402
+ self.self = RobertaSelfAttention(
403
+ config, position_embedding_type=position_embedding_type
404
+ )
405
+ self.output = RobertaSelfOutput(config)
406
+ self.pruned_heads = set()
407
+
408
+ def prune_heads(self, heads):
409
+ if len(heads) == 0:
410
+ return
411
+ heads, index = find_pruneable_heads_and_indices(
412
+ heads,
413
+ self.self.num_attention_heads,
414
+ self.self.attention_head_size,
415
+ self.pruned_heads,
416
+ )
417
+
418
+ # Prune linear layers
419
+ self.self.query = prune_linear_layer(self.self.query, index)
420
+ self.self.key = prune_linear_layer(self.self.key, index)
421
+ self.self.value = prune_linear_layer(self.self.value, index)
422
+ self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
423
+
424
+ # Update hyper params and store pruned heads
425
+ self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
426
+ self.self.all_head_size = (
427
+ self.self.attention_head_size * self.self.num_attention_heads
428
+ )
429
+ self.pruned_heads = self.pruned_heads.union(heads)
430
+
431
+ def forward(
432
+ self,
433
+ hidden_states: torch.Tensor,
434
+ attention_mask: Optional[torch.FloatTensor] = None,
435
+ head_mask: Optional[torch.FloatTensor] = None,
436
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
437
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
438
+ past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
439
+ output_attentions: Optional[bool] = False,
440
+ parser_att_mask=None,
441
+ ) -> Tuple[torch.Tensor]:
442
+ self_outputs = self.self(
443
+ hidden_states,
444
+ attention_mask,
445
+ head_mask,
446
+ encoder_hidden_states,
447
+ encoder_attention_mask,
448
+ past_key_value,
449
+ output_attentions,
450
+ parser_att_mask=parser_att_mask,
451
+ )
452
+ attention_output = self.output(self_outputs[0], hidden_states)
453
+ outputs = (attention_output,) + self_outputs[
454
+ 1:
455
+ ] # add attentions if we output them
456
+ return outputs
457
+
458
+
459
+ # Copied from transformers.models.bert.modeling_bert.BertIntermediate
460
+ class RobertaIntermediate(nn.Module):
461
+ def __init__(self, config):
462
+ super().__init__()
463
+ self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
464
+ if isinstance(config.hidden_act, str):
465
+ self.intermediate_act_fn = ACT2FN[config.hidden_act]
466
+ else:
467
+ self.intermediate_act_fn = config.hidden_act
468
+
469
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
470
+ hidden_states = self.dense(hidden_states)
471
+ hidden_states = self.intermediate_act_fn(hidden_states)
472
+ return hidden_states
473
+
474
+
475
+ # Copied from transformers.models.bert.modeling_bert.BertOutput
476
+ class RobertaOutput(nn.Module):
477
+ def __init__(self, config):
478
+ super().__init__()
479
+ self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
480
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
481
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
482
+
483
+ def forward(
484
+ self, hidden_states: torch.Tensor, input_tensor: torch.Tensor
485
+ ) -> torch.Tensor:
486
+ hidden_states = self.dense(hidden_states)
487
+ hidden_states = self.dropout(hidden_states)
488
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
489
+ return hidden_states
490
+
491
+
492
+ # Copied from transformers.models.bert.modeling_bert.BertLayer with Bert->Roberta
493
+ class RobertaLayer(nn.Module):
494
+ def __init__(self, config):
495
+ super().__init__()
496
+ self.chunk_size_feed_forward = config.chunk_size_feed_forward
497
+ self.seq_len_dim = 1
498
+ self.attention = RobertaAttention(config)
499
+ self.is_decoder = config.is_decoder
500
+ self.add_cross_attention = config.add_cross_attention
501
+ if self.add_cross_attention:
502
+ if not self.is_decoder:
503
+ raise ValueError(
504
+ f"{self} should be used as a decoder model if cross attention is added"
505
+ )
506
+ self.crossattention = RobertaAttention(
507
+ config, position_embedding_type="absolute"
508
+ )
509
+ self.intermediate = RobertaIntermediate(config)
510
+ self.output = RobertaOutput(config)
511
+
512
+ def forward(
513
+ self,
514
+ hidden_states: torch.Tensor,
515
+ attention_mask: Optional[torch.FloatTensor] = None,
516
+ head_mask: Optional[torch.FloatTensor] = None,
517
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
518
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
519
+ past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
520
+ output_attentions: Optional[bool] = False,
521
+ parser_att_mask=None,
522
+ ) -> Tuple[torch.Tensor]:
523
+ # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
524
+ self_attn_past_key_value = (
525
+ past_key_value[:2] if past_key_value is not None else None
526
+ )
527
+ self_attention_outputs = self.attention(
528
+ hidden_states,
529
+ attention_mask,
530
+ head_mask,
531
+ output_attentions=output_attentions,
532
+ past_key_value=self_attn_past_key_value,
533
+ parser_att_mask=parser_att_mask,
534
+ )
535
+ attention_output = self_attention_outputs[0]
536
+
537
+ # if decoder, the last output is tuple of self-attn cache
538
+ if self.is_decoder:
539
+ outputs = self_attention_outputs[1:-1]
540
+ present_key_value = self_attention_outputs[-1]
541
+ else:
542
+ outputs = self_attention_outputs[
543
+ 1:
544
+ ] # add self attentions if we output attention weights
545
+
546
+ cross_attn_present_key_value = None
547
+ if self.is_decoder and encoder_hidden_states is not None:
548
+ if not hasattr(self, "crossattention"):
549
+ raise ValueError(
550
+ f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers by setting `config.add_cross_attention=True`"
551
+ )
552
+
553
+ # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
554
+ cross_attn_past_key_value = (
555
+ past_key_value[-2:] if past_key_value is not None else None
556
+ )
557
+ cross_attention_outputs = self.crossattention(
558
+ attention_output,
559
+ attention_mask,
560
+ head_mask,
561
+ encoder_hidden_states,
562
+ encoder_attention_mask,
563
+ cross_attn_past_key_value,
564
+ output_attentions,
565
+ )
566
+ attention_output = cross_attention_outputs[0]
567
+ outputs = (
568
+ outputs + cross_attention_outputs[1:-1]
569
+ ) # add cross attentions if we output attention weights
570
+
571
+ # add cross-attn cache to positions 3,4 of present_key_value tuple
572
+ cross_attn_present_key_value = cross_attention_outputs[-1]
573
+ present_key_value = present_key_value + cross_attn_present_key_value
574
+
575
+ layer_output = apply_chunking_to_forward(
576
+ self.feed_forward_chunk,
577
+ self.chunk_size_feed_forward,
578
+ self.seq_len_dim,
579
+ attention_output,
580
+ )
581
+ outputs = (layer_output,) + outputs
582
+
583
+ # if decoder, return the attn key/values as the last output
584
+ if self.is_decoder:
585
+ outputs = outputs + (present_key_value,)
586
+
587
+ return outputs
588
+
589
+ def feed_forward_chunk(self, attention_output):
590
+ intermediate_output = self.intermediate(attention_output)
591
+ layer_output = self.output(intermediate_output, attention_output)
592
+ return layer_output
593
+
594
+
595
+ # Copied from transformers.models.bert.modeling_bert.BertEncoder with Bert->Roberta
596
+ class RobertaEncoder(nn.Module):
597
+ def __init__(self, config):
598
+ super().__init__()
599
+ self.config = config
600
+ self.layer = nn.ModuleList(
601
+ [RobertaLayer(config) for _ in range(config.num_hidden_layers)]
602
+ )
603
+ self.gradient_checkpointing = False
604
+
605
+ def forward(
606
+ self,
607
+ hidden_states: torch.Tensor,
608
+ attention_mask: Optional[torch.FloatTensor] = None,
609
+ head_mask: Optional[torch.FloatTensor] = None,
610
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
611
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
612
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
613
+ use_cache: Optional[bool] = None,
614
+ output_attentions: Optional[bool] = False,
615
+ output_hidden_states: Optional[bool] = False,
616
+ return_dict: Optional[bool] = True,
617
+ parser_att_mask=None,
618
+ ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]:
619
+ all_hidden_states = () if output_hidden_states else None
620
+ all_self_attentions = () if output_attentions else None
621
+ all_cross_attentions = (
622
+ () if output_attentions and self.config.add_cross_attention else None
623
+ )
624
+
625
+ next_decoder_cache = () if use_cache else None
626
+ for i, layer_module in enumerate(self.layer):
627
+ if output_hidden_states:
628
+ all_hidden_states = all_hidden_states + (hidden_states,)
629
+
630
+ layer_head_mask = head_mask[i] if head_mask is not None else None
631
+ past_key_value = past_key_values[i] if past_key_values is not None else None
632
+
633
+ if self.gradient_checkpointing and self.training:
634
+
635
+ if use_cache:
636
+ logger.warning(
637
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
638
+ )
639
+ use_cache = False
640
+
641
+ def create_custom_forward(module):
642
+ def custom_forward(*inputs):
643
+ return module(*inputs, past_key_value, output_attentions)
644
+
645
+ return custom_forward
646
+
647
+ layer_outputs = torch.utils.checkpoint.checkpoint(
648
+ create_custom_forward(layer_module),
649
+ hidden_states,
650
+ attention_mask,
651
+ layer_head_mask,
652
+ encoder_hidden_states,
653
+ encoder_attention_mask,
654
+ )
655
+ else:
656
+ if parser_att_mask is not None:
657
+ layer_outputs = layer_module(
658
+ hidden_states,
659
+ attention_mask,
660
+ layer_head_mask,
661
+ encoder_hidden_states,
662
+ encoder_attention_mask,
663
+ past_key_value,
664
+ output_attentions,
665
+ parser_att_mask=parser_att_mask[i])
666
+ else:
667
+ layer_outputs = layer_module(
668
+ hidden_states,
669
+ attention_mask,
670
+ layer_head_mask,
671
+ encoder_hidden_states,
672
+ encoder_attention_mask,
673
+ past_key_value,
674
+ output_attentions,
675
+ parser_att_mask=None)
676
+
677
+
678
+ hidden_states = layer_outputs[0]
679
+ if use_cache:
680
+ next_decoder_cache += (layer_outputs[-1],)
681
+ if output_attentions:
682
+ all_self_attentions = all_self_attentions + (layer_outputs[1],)
683
+ if self.config.add_cross_attention:
684
+ all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
685
+
686
+ if output_hidden_states:
687
+ all_hidden_states = all_hidden_states + (hidden_states,)
688
+
689
+ if not return_dict:
690
+ return tuple(
691
+ v
692
+ for v in [
693
+ hidden_states,
694
+ next_decoder_cache,
695
+ all_hidden_states,
696
+ all_self_attentions,
697
+ all_cross_attentions,
698
+ ]
699
+ if v is not None
700
+ )
701
+ return BaseModelOutputWithPastAndCrossAttentions(
702
+ last_hidden_state=hidden_states,
703
+ past_key_values=next_decoder_cache,
704
+ hidden_states=all_hidden_states,
705
+ attentions=all_self_attentions,
706
+ cross_attentions=all_cross_attentions,
707
+ )
708
+
709
+
710
+ # Copied from transformers.models.bert.modeling_bert.BertPooler
711
+ class RobertaPooler(nn.Module):
712
+ def __init__(self, config):
713
+ super().__init__()
714
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
715
+ self.activation = nn.Tanh()
716
+
717
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
718
+ # We "pool" the model by simply taking the hidden state corresponding
719
+ # to the first token.
720
+ first_token_tensor = hidden_states[:, 0]
721
+ pooled_output = self.dense(first_token_tensor)
722
+ pooled_output = self.activation(pooled_output)
723
+ return pooled_output
724
+
725
+
726
+ class RobertaPreTrainedModel(PreTrainedModel):
727
+ """
728
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
729
+ models.
730
+ """
731
+
732
+ config_class = RobertaConfig
733
+ base_model_prefix = "roberta"
734
+ supports_gradient_checkpointing = True
735
+
736
+ # Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights
737
+ def _init_weights(self, module):
738
+ """Initialize the weights"""
739
+ if isinstance(module, nn.Linear):
740
+ # Slightly different from the TF version which uses truncated_normal for initialization
741
+ # cf https://github.com/pytorch/pytorch/pull/5617
742
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
743
+ if module.bias is not None:
744
+ module.bias.data.zero_()
745
+ elif isinstance(module, nn.Embedding):
746
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
747
+ if module.padding_idx is not None:
748
+ module.weight.data[module.padding_idx].zero_()
749
+ elif isinstance(module, nn.LayerNorm):
750
+ if module.bias is not None:
751
+ module.bias.data.zero_()
752
+ module.weight.data.fill_(1.0)
753
+
754
+ def _set_gradient_checkpointing(self, module, value=False):
755
+ if isinstance(module, RobertaEncoder):
756
+ module.gradient_checkpointing = value
757
+
758
+ def update_keys_to_ignore(self, config, del_keys_to_ignore):
759
+ """Remove some keys from ignore list"""
760
+ if not config.tie_word_embeddings:
761
+ # must make a new list, or the class variable gets modified!
762
+ self._keys_to_ignore_on_save = [
763
+ k for k in self._keys_to_ignore_on_save if k not in del_keys_to_ignore
764
+ ]
765
+ self._keys_to_ignore_on_load_missing = [
766
+ k
767
+ for k in self._keys_to_ignore_on_load_missing
768
+ if k not in del_keys_to_ignore
769
+ ]
770
+
771
+
772
+ ROBERTA_START_DOCSTRING = r"""
773
+
774
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
775
+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
776
+ etc.)
777
+
778
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
779
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
780
+ and behavior.
781
+
782
+ Parameters:
783
+ config ([`RobertaConfig`]): Model configuration class with all the parameters of the
784
+ model. Initializing with a config file does not load the weights associated with the model, only the
785
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
786
+ """
787
+
788
+
789
+ ROBERTA_INPUTS_DOCSTRING = r"""
790
+ Args:
791
+ input_ids (`torch.LongTensor` of shape `({0})`):
792
+ Indices of input sequence tokens in the vocabulary.
793
+
794
+ Indices can be obtained using [`RobertaTokenizer`]. See [`PreTrainedTokenizer.encode`] and
795
+ [`PreTrainedTokenizer.__call__`] for details.
796
+
797
+ [What are input IDs?](../glossary#input-ids)
798
+ attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*):
799
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
800
+
801
+ - 1 for tokens that are **not masked**,
802
+ - 0 for tokens that are **masked**.
803
+
804
+ [What are attention masks?](../glossary#attention-mask)
805
+ token_type_ids (`torch.LongTensor` of shape `({0})`, *optional*):
806
+ Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
807
+ 1]`:
808
+
809
+ - 0 corresponds to a *sentence A* token,
810
+ - 1 corresponds to a *sentence B* token.
811
+
812
+ [What are token type IDs?](../glossary#token-type-ids)
813
+ position_ids (`torch.LongTensor` of shape `({0})`, *optional*):
814
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
815
+ config.max_position_embeddings - 1]`.
816
+
817
+ [What are position IDs?](../glossary#position-ids)
818
+ head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
819
+ Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
820
+
821
+ - 1 indicates the head is **not masked**,
822
+ - 0 indicates the head is **masked**.
823
+
824
+ inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*):
825
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
826
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
827
+ model's internal embedding lookup matrix.
828
+ output_attentions (`bool`, *optional*):
829
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
830
+ tensors for more detail.
831
+ output_hidden_states (`bool`, *optional*):
832
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
833
+ more detail.
834
+ return_dict (`bool`, *optional*):
835
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
836
+ """
837
+
838
+
839
+ class RobertaModel(RobertaPreTrainedModel):
840
+ """
841
+
842
+ The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
843
+ cross-attention is added between the self-attention layers, following the architecture described in *Attention is
844
+ all you need*_ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz
845
+ Kaiser and Illia Polosukhin.
846
+
847
+ To behave as an decoder the model needs to be initialized with the `is_decoder` argument of the configuration set
848
+ to `True`. To be used in a Seq2Seq model, the model needs to initialized with both `is_decoder` argument and
849
+ `add_cross_attention` set to `True`; an `encoder_hidden_states` is then expected as an input to the forward pass.
850
+
851
+ .. _*Attention is all you need*: https://arxiv.org/abs/1706.03762
852
+
853
+ """
854
+
855
+ _keys_to_ignore_on_load_missing = [r"position_ids"]
856
+
857
+ # Copied from transformers.models.bert.modeling_bert.BertModel.__init__ with Bert->Roberta
858
+ def __init__(self, config, add_pooling_layer=True):
859
+ super().__init__(config)
860
+ self.config = config
861
+
862
+ self.embeddings = RobertaEmbeddings(config)
863
+ self.encoder = RobertaEncoder(config)
864
+
865
+ self.pooler = RobertaPooler(config) if add_pooling_layer else None
866
+
867
+ # Initialize weights and apply final processing
868
+ self.post_init()
869
+
870
+ def get_input_embeddings(self):
871
+ return self.embeddings.word_embeddings
872
+
873
+ def set_input_embeddings(self, value):
874
+ self.embeddings.word_embeddings = value
875
+
876
+ def _prune_heads(self, heads_to_prune):
877
+ """
878
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
879
+ class PreTrainedModel
880
+ """
881
+ for layer, heads in heads_to_prune.items():
882
+ self.encoder.layer[layer].attention.prune_heads(heads)
883
+
884
+ # Copied from transformers.models.bert.modeling_bert.BertModel.forward
885
+ def forward(
886
+ self,
887
+ input_ids: Optional[torch.Tensor] = None,
888
+ attention_mask: Optional[torch.Tensor] = None,
889
+ token_type_ids: Optional[torch.Tensor] = None,
890
+ position_ids: Optional[torch.Tensor] = None,
891
+ head_mask: Optional[torch.Tensor] = None,
892
+ inputs_embeds: Optional[torch.Tensor] = None,
893
+ encoder_hidden_states: Optional[torch.Tensor] = None,
894
+ encoder_attention_mask: Optional[torch.Tensor] = None,
895
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
896
+ use_cache: Optional[bool] = None,
897
+ output_attentions: Optional[bool] = None,
898
+ output_hidden_states: Optional[bool] = None,
899
+ return_dict: Optional[bool] = None,
900
+ parser_att_mask=None,
901
+ ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
902
+ r"""
903
+ encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
904
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
905
+ the model is configured as a decoder.
906
+ encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
907
+ Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
908
+ the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:
909
+
910
+ - 1 for tokens that are **not masked**,
911
+ - 0 for tokens that are **masked**.
912
+ past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
913
+ Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
914
+
915
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
916
+ don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
917
+ `decoder_input_ids` of shape `(batch_size, sequence_length)`.
918
+ use_cache (`bool`, *optional*):
919
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
920
+ `past_key_values`).
921
+ """
922
+ output_attentions = (
923
+ output_attentions
924
+ if output_attentions is not None
925
+ else self.config.output_attentions
926
+ )
927
+ output_hidden_states = (
928
+ output_hidden_states
929
+ if output_hidden_states is not None
930
+ else self.config.output_hidden_states
931
+ )
932
+ return_dict = (
933
+ return_dict if return_dict is not None else self.config.use_return_dict
934
+ )
935
+
936
+ if self.config.is_decoder:
937
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
938
+ else:
939
+ use_cache = False
940
+
941
+ if input_ids is not None and inputs_embeds is not None:
942
+ raise ValueError(
943
+ "You cannot specify both input_ids and inputs_embeds at the same time"
944
+ )
945
+ elif input_ids is not None:
946
+ input_shape = input_ids.size()
947
+ elif inputs_embeds is not None:
948
+ input_shape = inputs_embeds.size()[:-1]
949
+ else:
950
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
951
+
952
+ batch_size, seq_length = input_shape
953
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
954
+
955
+ # past_key_values_length
956
+ past_key_values_length = (
957
+ past_key_values[0][0].shape[2] if past_key_values is not None else 0
958
+ )
959
+
960
+ if attention_mask is None:
961
+ attention_mask = torch.ones(
962
+ ((batch_size, seq_length + past_key_values_length)), device=device
963
+ )
964
+
965
+ if token_type_ids is None:
966
+ if hasattr(self.embeddings, "token_type_ids"):
967
+ buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length]
968
+ buffered_token_type_ids_expanded = buffered_token_type_ids.expand(
969
+ batch_size, seq_length
970
+ )
971
+ token_type_ids = buffered_token_type_ids_expanded
972
+ else:
973
+ token_type_ids = torch.zeros(
974
+ input_shape, dtype=torch.long, device=device
975
+ )
976
+
977
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
978
+ # ourselves in which case we just need to make it broadcastable to all heads.
979
+ extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(
980
+ attention_mask, input_shape, device
981
+ )
982
+
983
+ # If a 2D or 3D attention mask is provided for the cross-attention
984
+ # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
985
+ if self.config.is_decoder and encoder_hidden_states is not None:
986
+ (
987
+ encoder_batch_size,
988
+ encoder_sequence_length,
989
+ _,
990
+ ) = encoder_hidden_states.size()
991
+ encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
992
+ if encoder_attention_mask is None:
993
+ encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
994
+ encoder_extended_attention_mask = self.invert_attention_mask(
995
+ encoder_attention_mask
996
+ )
997
+ else:
998
+ encoder_extended_attention_mask = None
999
+
1000
+ # Prepare head mask if needed
1001
+ # 1.0 in head_mask indicate we keep the head
1002
+ # attention_probs has shape bsz x n_heads x N x N
1003
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
1004
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
1005
+ head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
1006
+
1007
+ embedding_output = self.embeddings(
1008
+ input_ids=input_ids,
1009
+ position_ids=position_ids,
1010
+ token_type_ids=token_type_ids,
1011
+ inputs_embeds=inputs_embeds,
1012
+ past_key_values_length=past_key_values_length,
1013
+ )
1014
+ encoder_outputs = self.encoder(
1015
+ embedding_output,
1016
+ attention_mask=extended_attention_mask,
1017
+ head_mask=head_mask,
1018
+ encoder_hidden_states=encoder_hidden_states,
1019
+ encoder_attention_mask=encoder_extended_attention_mask,
1020
+ past_key_values=past_key_values,
1021
+ use_cache=use_cache,
1022
+ output_attentions=output_attentions,
1023
+ output_hidden_states=output_hidden_states,
1024
+ return_dict=return_dict,
1025
+ parser_att_mask=parser_att_mask,
1026
+ )
1027
+ sequence_output = encoder_outputs[0]
1028
+ pooled_output = (
1029
+ self.pooler(sequence_output) if self.pooler is not None else None
1030
+ )
1031
+
1032
+ if not return_dict:
1033
+ return (sequence_output, pooled_output) + encoder_outputs[1:]
1034
+
1035
+ return BaseModelOutputWithPoolingAndCrossAttentions(
1036
+ last_hidden_state=sequence_output,
1037
+ pooler_output=pooled_output,
1038
+ past_key_values=encoder_outputs.past_key_values,
1039
+ hidden_states=encoder_outputs.hidden_states,
1040
+ attentions=encoder_outputs.attentions,
1041
+ cross_attentions=encoder_outputs.cross_attentions,
1042
+ )
1043
+
1044
+
1045
+ class StructRoberta(RobertaPreTrainedModel):
1046
+ _keys_to_ignore_on_save = [r"lm_head.decoder.weight", r"lm_head.decoder.bias"]
1047
+ _keys_to_ignore_on_load_missing = [
1048
+ r"position_ids",
1049
+ r"lm_head.decoder.weight",
1050
+ r"lm_head.decoder.bias",
1051
+ ]
1052
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
1053
+
1054
+ def __init__(self, config):
1055
+ super().__init__(config)
1056
+
1057
+ if config.is_decoder:
1058
+ logger.warning(
1059
+ "If you want to use `RobertaForMaskedLM` make sure `config.is_decoder=False` for "
1060
+ "bi-directional self-attention."
1061
+ )
1062
+
1063
+
1064
+ if config.n_cntxt_layers > 0:
1065
+ config_cntxt = copy.deepcopy(config)
1066
+ config_cntxt.num_hidden_layers = config.n_cntxt_layers
1067
+
1068
+ self.cntxt_layers = RobertaModel(config_cntxt, add_pooling_layer=False)
1069
+
1070
+ if config.n_cntxt_layers_2 > 0:
1071
+ self.parser_layers_1 = nn.ModuleList(
1072
+ [
1073
+ nn.Sequential(
1074
+ Conv1d(config.hidden_size, config.conv_size),
1075
+ nn.LayerNorm(config.hidden_size, elementwise_affine=False),
1076
+ nn.Tanh(),
1077
+ )
1078
+ for i in range(int(config.n_parser_layers/2))
1079
+ ]
1080
+ )
1081
+
1082
+ self.distance_ff_1 = nn.Sequential(
1083
+ Conv1d(config.hidden_size, 2),
1084
+ nn.LayerNorm(config.hidden_size, elementwise_affine=False),
1085
+ nn.Tanh(),
1086
+ nn.Linear(config.hidden_size, 1),
1087
+ )
1088
+
1089
+ self.height_ff_1 = nn.Sequential(
1090
+ nn.Linear(config.hidden_size, config.hidden_size),
1091
+ nn.LayerNorm(config.hidden_size, elementwise_affine=False),
1092
+ nn.Tanh(),
1093
+ nn.Linear(config.hidden_size, 1),
1094
+ )
1095
+
1096
+ n_rel = len(config.relations)
1097
+ self._rel_weight_1 = nn.Parameter(
1098
+ torch.zeros((config.n_cntxt_layers_2, config.num_attention_heads, n_rel))
1099
+ )
1100
+ self._rel_weight_1.data.normal_(0, 0.1)
1101
+
1102
+ self._scaler_1 = nn.Parameter(torch.zeros(2))
1103
+
1104
+ config_cntxt_2 = copy.deepcopy(config)
1105
+ config_cntxt_2.num_hidden_layers = config.n_cntxt_layers_2
1106
+
1107
+ self.cntxt_layers_2 = RobertaModel(config_cntxt_2, add_pooling_layer=False)
1108
+
1109
+
1110
+ self.parser_layers_2 = nn.ModuleList(
1111
+ [
1112
+ nn.Sequential(
1113
+ Conv1d(config.hidden_size, config.conv_size),
1114
+ nn.LayerNorm(config.hidden_size, elementwise_affine=False),
1115
+ nn.Tanh(),
1116
+ )
1117
+ for i in range(int(config.n_parser_layers/2))
1118
+ ]
1119
+ )
1120
+
1121
+ self.distance_ff_2 = nn.Sequential(
1122
+ Conv1d(config.hidden_size, 2),
1123
+ nn.LayerNorm(config.hidden_size, elementwise_affine=False),
1124
+ nn.Tanh(),
1125
+ nn.Linear(config.hidden_size, 1),
1126
+ )
1127
+
1128
+ self.height_ff_2 = nn.Sequential(
1129
+ nn.Linear(config.hidden_size, config.hidden_size),
1130
+ nn.LayerNorm(config.hidden_size, elementwise_affine=False),
1131
+ nn.Tanh(),
1132
+ nn.Linear(config.hidden_size, 1),
1133
+ )
1134
+
1135
+ n_rel = len(config.relations)
1136
+ self._rel_weight_2 = nn.Parameter(
1137
+ torch.zeros((config.num_hidden_layers, config.num_attention_heads, n_rel))
1138
+ )
1139
+ self._rel_weight_2.data.normal_(0, 0.1)
1140
+
1141
+ self._scaler_2 = nn.Parameter(torch.zeros(2))
1142
+
1143
+ else:
1144
+ self.parser_layers = nn.ModuleList(
1145
+ [
1146
+ nn.Sequential(
1147
+ Conv1d(config.hidden_size, config.conv_size),
1148
+ nn.LayerNorm(config.hidden_size, elementwise_affine=False),
1149
+ nn.Tanh(),
1150
+ )
1151
+ for i in range(config.n_parser_layers)
1152
+ ]
1153
+ )
1154
+
1155
+ self.distance_ff = nn.Sequential(
1156
+ Conv1d(config.hidden_size, 2),
1157
+ nn.LayerNorm(config.hidden_size, elementwise_affine=False),
1158
+ nn.Tanh(),
1159
+ nn.Linear(config.hidden_size, 1),
1160
+ )
1161
+
1162
+ self.height_ff = nn.Sequential(
1163
+ nn.Linear(config.hidden_size, config.hidden_size),
1164
+ nn.LayerNorm(config.hidden_size, elementwise_affine=False),
1165
+ nn.Tanh(),
1166
+ nn.Linear(config.hidden_size, 1),
1167
+ )
1168
+
1169
+ n_rel = len(config.relations)
1170
+ self._rel_weight = nn.Parameter(
1171
+ torch.zeros((config.num_hidden_layers, config.num_attention_heads, n_rel))
1172
+ )
1173
+ self._rel_weight.data.normal_(0, 0.1)
1174
+
1175
+ self._scaler = nn.Parameter(torch.zeros(2))
1176
+
1177
+ self.roberta = RobertaModel(config, add_pooling_layer=False)
1178
+
1179
+ if config.n_cntxt_layers > 0:
1180
+ self.cntxt_layers.embeddings = self.roberta.embeddings
1181
+ if config.n_cntxt_layers_2 > 0:
1182
+ self.cntxt_layers_2.embeddings = self.roberta.embeddings
1183
+
1184
+ self.lm_head = RobertaLMHead(config)
1185
+
1186
+ self.pad = config.pad_token_id
1187
+
1188
+ # The LM head weights require special treatment only when they are tied with the word embeddings
1189
+ self.update_keys_to_ignore(config, ["lm_head.decoder.weight"])
1190
+
1191
+ # Initialize weights and apply final processing
1192
+ self.post_init()
1193
+
1194
+ def get_output_embeddings(self):
1195
+ return self.lm_head.decoder
1196
+
1197
+ def set_output_embeddings(self, new_embeddings):
1198
+ self.lm_head.decoder = new_embeddings
1199
+
1200
+ @property
1201
+ def scaler(self):
1202
+ return self._scaler.exp()
1203
+
1204
+ @property
1205
+ def scaler_1(self):
1206
+ return self._scaler_1.exp()
1207
+
1208
+ @property
1209
+ def scaler_2(self):
1210
+ return self._scaler_2.exp()
1211
+
1212
+ @property
1213
+ def rel_weight(self):
1214
+ if self.config.weight_act == "sigmoid":
1215
+ return torch.sigmoid(self._rel_weight)
1216
+ elif self.config.weight_act == "softmax":
1217
+ return torch.softmax(self._rel_weight, dim=-1)
1218
+
1219
+ @property
1220
+ def rel_weight_1(self):
1221
+ if self.config.weight_act == "sigmoid":
1222
+ return torch.sigmoid(self._rel_weight_1)
1223
+ elif self.config.weight_act == "softmax":
1224
+ return torch.softmax(self._rel_weight_1, dim=-1)
1225
+
1226
+
1227
+ @property
1228
+ def rel_weight_2(self):
1229
+ if self.config.weight_act == "sigmoid":
1230
+ return torch.sigmoid(self._rel_weight_2)
1231
+ elif self.config.weight_act == "softmax":
1232
+ return torch.softmax(self._rel_weight_2, dim=-1)
1233
+
1234
+
1235
+ def compute_block(self, distance, height, n_cntxt_layers=0):
1236
+ """Compute constituents from distance and height."""
1237
+
1238
+ if n_cntxt_layers>0:
1239
+ if n_cntxt_layers == 1:
1240
+ beta_logits = (distance[:, None, :] - height[:, :, None]) * self.scaler_1[0]
1241
+ elif n_cntxt_layers == 2:
1242
+ beta_logits = (distance[:, None, :] - height[:, :, None]) * self.scaler_2[0]
1243
+ else:
1244
+ beta_logits = (distance[:, None, :] - height[:, :, None]) * self.scaler[0]
1245
+
1246
+ gamma = torch.sigmoid(-beta_logits)
1247
+ ones = torch.ones_like(gamma)
1248
+
1249
+ block_mask_left = cummin(
1250
+ gamma.tril(-1) + ones.triu(0), reverse=True, max_value=1
1251
+ )
1252
+ block_mask_left = block_mask_left - F.pad(
1253
+ block_mask_left[:, :, :-1], (1, 0), value=0
1254
+ )
1255
+ block_mask_left.tril_(0)
1256
+
1257
+ block_mask_right = cummin(
1258
+ gamma.triu(0) + ones.tril(-1), exclusive=True, max_value=1
1259
+ )
1260
+ block_mask_right = block_mask_right - F.pad(
1261
+ block_mask_right[:, :, 1:], (0, 1), value=0
1262
+ )
1263
+ block_mask_right.triu_(0)
1264
+
1265
+ block_p = block_mask_left[:, :, :, None] * block_mask_right[:, :, None, :]
1266
+ block = cumsum(block_mask_left).tril(0) + cumsum(
1267
+ block_mask_right, reverse=True
1268
+ ).triu(1)
1269
+
1270
+ return block_p, block
1271
+
1272
+ def compute_head(self, height, n_cntxt_layers=0):
1273
+ """Estimate head for each constituent."""
1274
+
1275
+ _, length = height.size()
1276
+ if n_cntxt_layers>0:
1277
+ if n_cntxt_layers == 1:
1278
+ head_logits = height * self.scaler_1[1]
1279
+ elif n_cntxt_layers == 2:
1280
+ head_logits = height * self.scaler_2[1]
1281
+ else:
1282
+ head_logits = height * self.scaler[1]
1283
+ index = torch.arange(length, device=height.device)
1284
+
1285
+ mask = (index[:, None, None] <= index[None, None, :]) * (
1286
+ index[None, None, :] <= index[None, :, None]
1287
+ )
1288
+ head_logits = head_logits[:, None, None, :].repeat(1, length, length, 1)
1289
+ head_logits.masked_fill_(~mask[None, :, :, :], -1e9)
1290
+
1291
+ head_p = torch.softmax(head_logits, dim=-1)
1292
+
1293
+ return head_p
1294
+
1295
+ def parse(self, x, embs=None, n_cntxt_layers=0):
1296
+ """Parse input sentence.
1297
+
1298
+ Args:
1299
+ x: input tokens (required).
1300
+ pos: position for each token (optional).
1301
+ Returns:
1302
+ distance: syntactic distance
1303
+ height: syntactic height
1304
+ """
1305
+
1306
+ mask = x != self.pad
1307
+ mask_shifted = F.pad(mask[:, 1:], (0, 1), value=0)
1308
+
1309
+ if embs is None:
1310
+ h = self.roberta.embeddings(x)
1311
+ else:
1312
+ h = embs
1313
+
1314
+ if n_cntxt_layers > 0:
1315
+ if n_cntxt_layers == 1:
1316
+ parser_layers = self.parser_layers_1
1317
+ height_ff = self.height_ff_1
1318
+ distance_ff = self.distance_ff_1
1319
+ elif n_cntxt_layers == 2:
1320
+ parser_layers = self.parser_layers_2
1321
+ height_ff = self.height_ff_2
1322
+ distance_ff = self.distance_ff_2
1323
+ for i in range(int(self.config.n_parser_layers/2)):
1324
+ h = h.masked_fill(~mask[:, :, None], 0)
1325
+ h = parser_layers[i](h)
1326
+
1327
+ height = height_ff(h).squeeze(-1)
1328
+ height.masked_fill_(~mask, -1e9)
1329
+
1330
+ distance = distance_ff(h).squeeze(-1)
1331
+ distance.masked_fill_(~mask_shifted, 1e9)
1332
+
1333
+ # Calbrating the distance and height to the same level
1334
+ length = distance.size(1)
1335
+ height_max = height[:, None, :].expand(-1, length, -1)
1336
+ height_max = torch.cummax(
1337
+ height_max.triu(0) - torch.ones_like(height_max).tril(-1) * 1e9, dim=-1
1338
+ )[0].triu(0)
1339
+
1340
+ margin_left = torch.relu(
1341
+ F.pad(distance[:, :-1, None], (0, 0, 1, 0), value=1e9) - height_max
1342
+ )
1343
+ margin_right = torch.relu(distance[:, None, :] - height_max)
1344
+ margin = torch.where(
1345
+ margin_left > margin_right, margin_right, margin_left
1346
+ ).triu(0)
1347
+
1348
+ margin_mask = torch.stack([mask_shifted] + [mask] * (length - 1), dim=1)
1349
+ margin.masked_fill_(~margin_mask, 0)
1350
+ margin = margin.max()
1351
+
1352
+ distance = distance - margin
1353
+ else:
1354
+ for i in range(self.config.n_parser_layers):
1355
+ h = h.masked_fill(~mask[:, :, None], 0)
1356
+ h = self.parser_layers[i](h)
1357
+
1358
+ height = self.height_ff(h).squeeze(-1)
1359
+ height.masked_fill_(~mask, -1e9)
1360
+
1361
+ distance = self.distance_ff(h).squeeze(-1)
1362
+ distance.masked_fill_(~mask_shifted, 1e9)
1363
+
1364
+ # Calibrating the distance and height to the same level
1365
+ length = distance.size(1)
1366
+ height_max = height[:, None, :].expand(-1, length, -1)
1367
+ height_max = torch.cummax(
1368
+ height_max.triu(0) - torch.ones_like(height_max).tril(-1) * 1e9, dim=-1
1369
+ )[0].triu(0)
1370
+
1371
+ margin_left = torch.relu(
1372
+ F.pad(distance[:, :-1, None], (0, 0, 1, 0), value=1e9) - height_max
1373
+ )
1374
+ margin_right = torch.relu(distance[:, None, :] - height_max)
1375
+ margin = torch.where(
1376
+ margin_left > margin_right, margin_right, margin_left
1377
+ ).triu(0)
1378
+
1379
+ margin_mask = torch.stack([mask_shifted] + [mask] * (length - 1), dim=1)
1380
+ margin.masked_fill_(~margin_mask, 0)
1381
+ margin = margin.max()
1382
+
1383
+ distance = distance - margin
1384
+
1385
+ return distance, height
1386
+
1387
+ def generate_mask(self, x, distance, height, n_cntxt_layers=0):
1388
+ """Compute head and cibling distribution for each token."""
1389
+
1390
+ bsz, length = x.size()
1391
+
1392
+ eye = torch.eye(length, device=x.device, dtype=torch.bool)
1393
+ eye = eye[None, :, :].expand((bsz, -1, -1))
1394
+
1395
+ block_p, block = self.compute_block(distance, height, n_cntxt_layers=n_cntxt_layers)
1396
+ head_p = self.compute_head(height, n_cntxt_layers=n_cntxt_layers)
1397
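+ # Marginalize the per-span head distribution over each token's span distribution: head[b, i, h] is the
+ # probability that token h heads the constituent of token i; child is its transpose, cibling composes the two.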
+ head = torch.einsum("blij,bijh->blh", block_p, head_p)
1398
+ head = head.masked_fill(eye, 0)
1399
+ child = head.transpose(1, 2)
1400
+ cibling = torch.bmm(head, child).masked_fill(eye, 0)
1401
+
1402
+ rel_list = []
1403
+ if "head" in self.config.relations:
1404
+ rel_list.append(head)
1405
+ if "child" in self.config.relations:
1406
+ rel_list.append(child)
1407
+ if "cibling" in self.config.relations:
1408
+ rel_list.append(cibling)
1409
+
1410
+ rel = torch.stack(rel_list, dim=1)
1411
+
1412
+ if n_cntxt_layers > 0:
1413
+ if n_cntxt_layers == 1:
1414
+ rel_weight = self.rel_weight_1
1415
+ elif n_cntxt_layers == 2:
1416
+ rel_weight = self.rel_weight_2
1417
+ else:
1418
+ rel_weight = self.rel_weight
1419
+
1420
+ dep = torch.einsum("lhr,brij->lbhij", rel_weight, rel)
1421
+
1422
+ if n_cntxt_layers == 1:
1423
+ num_layers = self.cntxt_layers_2.config.num_hidden_layers
1424
+ else:
1425
+ num_layers = self.roberta.config.num_hidden_layers
1426
+
1427
+ att_mask = dep.reshape(
1428
+ num_layers,
1429
+ bsz,
1430
+ self.config.num_attention_heads,
1431
+ length,
1432
+ length,
1433
+ )
1434
+
1435
+ return att_mask, cibling, head, block
1436
+
1437
+ def forward(
1438
+ self,
1439
+ input_ids: Optional[torch.LongTensor] = None,
1440
+ attention_mask: Optional[torch.FloatTensor] = None,
1441
+ token_type_ids: Optional[torch.LongTensor] = None,
1442
+ position_ids: Optional[torch.LongTensor] = None,
1443
+ head_mask: Optional[torch.FloatTensor] = None,
1444
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1445
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
1446
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
1447
+ labels: Optional[torch.LongTensor] = None,
1448
+ output_attentions: Optional[bool] = None,
1449
+ output_hidden_states: Optional[bool] = None,
1450
+ return_dict: Optional[bool] = None,
1451
+ ) -> Union[Tuple, MaskedLMOutput]:
1452
+ r"""
1453
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1454
+ Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
1455
+ config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
1456
+ loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
1457
+ kwargs (`Dict[str, any]`, optional, defaults to *{}*):
1458
+ Used to hide legacy arguments that have been deprecated.
1459
+ """
1460
+ return_dict = (
1461
+ return_dict if return_dict is not None else self.config.use_return_dict
1462
+ )
1463
+
1464
+
1465
+ if self.config.n_cntxt_layers > 0:
1466
+ cntxt_outputs = self.cntxt_layers(
1467
+ input_ids,
1468
+ attention_mask=attention_mask,
1469
+ token_type_ids=token_type_ids,
1470
+ position_ids=position_ids,
1471
+ head_mask=head_mask,
1472
+ inputs_embeds=inputs_embeds,
1473
+ encoder_hidden_states=encoder_hidden_states,
1474
+ encoder_attention_mask=encoder_attention_mask,
1475
+ output_attentions=output_attentions,
1476
+ output_hidden_states=output_hidden_states,
1477
+ return_dict=return_dict)
1478
+
1479
+
1480
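+ # Two-stage parsing: the first context stack's output is parsed to mask cntxt_layers_2, whose output
+ # is parsed again to build the attention mask for the main RoBERTa encoder.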
+ if self.config.n_cntxt_layers_2 > 0:
1481
+ distance_1, height_1 = self.parse(input_ids, cntxt_outputs[0], n_cntxt_layers=1)
1482
+ att_mask_1, _, _, _ = self.generate_mask(input_ids, distance_1, height_1, n_cntxt_layers=1)
1483
+
1484
+ cntxt_outputs_2 = self.cntxt_layers_2(
1485
+ input_ids,
1486
+ attention_mask=attention_mask,
1487
+ token_type_ids=token_type_ids,
1488
+ position_ids=position_ids,
1489
+ head_mask=head_mask,
1490
+ inputs_embeds=inputs_embeds,
1491
+ encoder_hidden_states=encoder_hidden_states,
1492
+ encoder_attention_mask=encoder_attention_mask,
1493
+ output_attentions=output_attentions,
1494
+ output_hidden_states=output_hidden_states,
1495
+ return_dict=return_dict,
1496
+ parser_att_mask=att_mask_1)
1497
+
1498
+ sequence_output = cntxt_outputs_2[0]
1499
+
1500
+ distance_2, height_2 = self.parse(input_ids, sequence_output, n_cntxt_layers=2)
1501
+ att_mask, _, _, _ = self.generate_mask(input_ids, distance_2, height_2, n_cntxt_layers=2)
1502
+
1503
+ elif self.config.n_cntxt_layers > 0:
1504
+ distance, height = self.parse(input_ids, cntxt_outputs[0])
1505
+ att_mask, _, _, _ = self.generate_mask(input_ids, distance, height)
1506
+ else:
1507
+ distance, height = self.parse(input_ids)
1508
+ att_mask, _, _, _ = self.generate_mask(input_ids, distance, height)
1509
+
1510
+ outputs = self.roberta(
1511
+ input_ids,
1512
+ attention_mask=attention_mask,
1513
+ token_type_ids=token_type_ids,
1514
+ position_ids=position_ids,
1515
+ head_mask=head_mask,
1516
+ inputs_embeds=inputs_embeds,
1517
+ encoder_hidden_states=encoder_hidden_states,
1518
+ encoder_attention_mask=encoder_attention_mask,
1519
+ output_attentions=output_attentions,
1520
+ output_hidden_states=output_hidden_states,
1521
+ return_dict=return_dict,
1522
+ parser_att_mask=att_mask,
1523
+ )
1524
+ sequence_output = outputs[0]
1525
+ prediction_scores = self.lm_head(sequence_output)
1526
+
1527
+ masked_lm_loss = None
1528
+ if labels is not None:
1529
+ loss_fct = CrossEntropyLoss()
1530
+ masked_lm_loss = loss_fct(
1531
+ prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)
1532
+ )
1533
+
1534
+ if not return_dict:
1535
+ output = (prediction_scores,) + outputs[2:]
1536
+ return (
1537
+ ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
1538
+ )
1539
+
1540
+ return MaskedLMOutput(
1541
+ loss=masked_lm_loss,
1542
+ logits=prediction_scores,
1543
+ hidden_states=outputs.hidden_states,
1544
+ attentions=outputs.attentions,
1545
+ )
1546
+
1547
+
1548
+ class RobertaLMHead(nn.Module):
1549
+ """Roberta Head for masked language modeling."""
1550
+
1551
+ def __init__(self, config):
1552
+ super().__init__()
1553
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
1554
+ self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
1555
+
1556
+ self.decoder = nn.Linear(config.hidden_size, config.vocab_size)
1557
+ self.bias = nn.Parameter(torch.zeros(config.vocab_size))
1558
+ self.decoder.bias = self.bias
1559
+
1560
+ def forward(self, features, **kwargs):
1561
+ x = self.dense(features)
1562
+ x = gelu(x)
1563
+ x = self.layer_norm(x)
1564
+
1565
+ # project back to size of vocabulary with bias
1566
+ x = self.decoder(x)
1567
+
1568
+ return x
1569
+
1570
+ def _tie_weights(self):
1571
+ # To tie those two weights if they get disconnected (on TPU or when the bias is resized)
1572
+ self.bias = self.decoder.bias
1573
+
1574
+
1575
+ class StructRobertaForSequenceClassification(RobertaPreTrainedModel):
1576
+ _keys_to_ignore_on_load_missing = [r"position_ids"]
1577
+
1578
+ def __init__(self, config):
1579
+ super().__init__(config)
1580
+ self.num_labels = config.num_labels
1581
+ self.config = config
1582
+
1583
+ if config.n_cntxt_layers > 0:
1584
+ config_cntxt = copy.deepcopy(config)
1585
+ config_cntxt.num_hidden_layers = config.n_cntxt_layers
1586
+
1587
+ self.cntxt_layers = RobertaModel(config_cntxt, add_pooling_layer=False)
1588
+
1589
+ if config.n_cntxt_layers_2 > 0:
1590
+ self.parser_layers_1 = nn.ModuleList(
1591
+ [
1592
+ nn.Sequential(
1593
+ Conv1d(config.hidden_size, config.conv_size),
1594
+ nn.LayerNorm(config.hidden_size, elementwise_affine=False),
1595
+ nn.Tanh(),
1596
+ )
1597
+ for i in range(int(config.n_parser_layers/2))
1598
+ ]
1599
+ )
1600
+
1601
+ self.distance_ff_1 = nn.Sequential(
1602
+ Conv1d(config.hidden_size, 2),
1603
+ nn.LayerNorm(config.hidden_size, elementwise_affine=False),
1604
+ nn.Tanh(),
1605
+ nn.Linear(config.hidden_size, 1),
1606
+ )
1607
+
1608
+ self.height_ff_1 = nn.Sequential(
1609
+ nn.Linear(config.hidden_size, config.hidden_size),
1610
+ nn.LayerNorm(config.hidden_size, elementwise_affine=False),
1611
+ nn.Tanh(),
1612
+ nn.Linear(config.hidden_size, 1),
1613
+ )
1614
+
1615
+ n_rel = len(config.relations)
1616
+ self._rel_weight_1 = nn.Parameter(
1617
+ torch.zeros((config.n_cntxt_layers_2, config.num_attention_heads, n_rel))
1618
+ )
1619
+ self._rel_weight_1.data.normal_(0, 0.1)
1620
+
1621
+ self._scaler_1 = nn.Parameter(torch.zeros(2))
1622
+
1623
+ config_cntxt_2 = copy.deepcopy(config)
1624
+ config_cntxt_2.num_hidden_layers = config.n_cntxt_layers_2
1625
+
1626
+ self.cntxt_layers_2 = RobertaModel(config_cntxt_2, add_pooling_layer=False)
1627
+
1628
+
1629
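+ # Second parser stack: re-parses the output of cntxt_layers_2 to derive the attention mask for the main encoder.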
+ self.parser_layers_2 = nn.ModuleList(
1630
+ [
1631
+ nn.Sequential(
1632
+ Conv1d(config.hidden_size, config.conv_size),
1633
+ nn.LayerNorm(config.hidden_size, elementwise_affine=False),
1634
+ nn.Tanh(),
1635
+ )
1636
+ for i in range(int(config.n_parser_layers/2))
1637
+ ]
1638
+ )
1639
+
1640
+ self.distance_ff_2 = nn.Sequential(
1641
+ Conv1d(config.hidden_size, 2),
1642
+ nn.LayerNorm(config.hidden_size, elementwise_affine=False),
1643
+ nn.Tanh(),
1644
+ nn.Linear(config.hidden_size, 1),
1645
+ )
1646
+
1647
+ self.height_ff_2 = nn.Sequential(
1648
+ nn.Linear(config.hidden_size, config.hidden_size),
1649
+ nn.LayerNorm(config.hidden_size, elementwise_affine=False),
1650
+ nn.Tanh(),
1651
+ nn.Linear(config.hidden_size, 1),
1652
+ )
1653
+
1654
+ n_rel = len(config.relations)
1655
+ self._rel_weight_2 = nn.Parameter(
1656
+ torch.zeros((config.num_hidden_layers, config.num_attention_heads, n_rel))
1657
+ )
1658
+ self._rel_weight_2.data.normal_(0, 0.1)
1659
+
1660
+ self._scaler_2 = nn.Parameter(torch.zeros(2))
1661
+
1662
+ else:
1663
+ self.parser_layers = nn.ModuleList(
1664
+ [
1665
+ nn.Sequential(
1666
+ Conv1d(config.hidden_size, config.conv_size),
1667
+ nn.LayerNorm(config.hidden_size, elementwise_affine=False),
1668
+ nn.Tanh(),
1669
+ )
1670
+ for i in range(config.n_parser_layers)
1671
+ ]
1672
+ )
1673
+
1674
+ self.distance_ff = nn.Sequential(
1675
+ Conv1d(config.hidden_size, 2),
1676
+ nn.LayerNorm(config.hidden_size, elementwise_affine=False),
1677
+ nn.Tanh(),
1678
+ nn.Linear(config.hidden_size, 1),
1679
+ )
1680
+
1681
+ self.height_ff = nn.Sequential(
1682
+ nn.Linear(config.hidden_size, config.hidden_size),
1683
+ nn.LayerNorm(config.hidden_size, elementwise_affine=False),
1684
+ nn.Tanh(),
1685
+ nn.Linear(config.hidden_size, 1),
1686
+ )
1687
+
1688
+ n_rel = len(config.relations)
1689
+ self._rel_weight = nn.Parameter(
1690
+ torch.zeros((config.num_hidden_layers, config.num_attention_heads, n_rel))
1691
+ )
1692
+ self._rel_weight.data.normal_(0, 0.1)
1693
+
1694
+ self._scaler = nn.Parameter(torch.zeros(2))
1695
+
1696
+ self.roberta = RobertaModel(config, add_pooling_layer=False)
1697
+
1698
+ if config.n_cntxt_layers > 0:
1699
+ self.cntxt_layers.embeddings = self.roberta.embeddings
1700
+ if config.n_cntxt_layers_2 > 0:
1701
+ self.cntxt_layers_2.embeddings = self.roberta.embeddings
1702
+
1703
+
1704
+ self.pad = config.pad_token_id
1705
+ self.classifier = RobertaClassificationHead(config)
1706
+
1707
+ # Initialize weights and apply final processing
1708
+ self.post_init()
1709
+
1710
+
1711
+ @property
1712
+ def scaler(self):
1713
+ return self._scaler.exp()
1714
+
1715
+ @property
1716
+ def scaler_1(self):
1717
+ return self._scaler_1.exp()
1718
+
1719
+ @property
1720
+ def scaler_2(self):
1721
+ return self._scaler_2.exp()
1722
+
1723
+ @property
1724
+ def rel_weight(self):
1725
+ if self.config.weight_act == "sigmoid":
1726
+ return torch.sigmoid(self._rel_weight)
1727
+ elif self.config.weight_act == "softmax":
1728
+ return torch.softmax(self._rel_weight, dim=-1)
1729
+
1730
+ @property
1731
+ def rel_weight_1(self):
1732
+ if self.config.weight_act == "sigmoid":
1733
+ return torch.sigmoid(self._rel_weight_1)
1734
+ elif self.config.weight_act == "softmax":
1735
+ return torch.softmax(self._rel_weight_1, dim=-1)
1736
+
1737
+
1738
+ @property
1739
+ def rel_weight_2(self):
1740
+ if self.config.weight_act == "sigmoid":
1741
+ return torch.sigmoid(self._rel_weight_2)
1742
+ elif self.config.weight_act == "softmax":
1743
+ return torch.softmax(self._rel_weight_2, dim=-1)
1744
+
1745
+
1746
+ def compute_block(self, distance, height, n_cntxt_layers=0):
1747
+ """Compute constituents from distance and height."""
1748
+
1749
+ if n_cntxt_layers > 0:
1750
+ if n_cntxt_layers == 1:
1751
+ beta_logits = (distance[:, None, :] - height[:, :, None]) * self.scaler_1[0]
1752
+ elif n_cntxt_layers == 2:
1753
+ beta_logits = (distance[:, None, :] - height[:, :, None]) * self.scaler_2[0]
1754
+ else:
1755
+ beta_logits = (distance[:, None, :] - height[:, :, None]) * self.scaler[0]
1756
+
1757
+ gamma = torch.sigmoid(-beta_logits)
1758
+ ones = torch.ones_like(gamma)
1759
+
1760
+ block_mask_left = cummin(
1761
+ gamma.tril(-1) + ones.triu(0), reverse=True, max_value=1
1762
+ )
1763
+ block_mask_left = block_mask_left - F.pad(
1764
+ block_mask_left[:, :, :-1], (1, 0), value=0
1765
+ )
1766
+ block_mask_left.tril_(0)
1767
+
1768
+ block_mask_right = cummin(
1769
+ gamma.triu(0) + ones.tril(-1), exclusive=True, max_value=1
1770
+ )
1771
+ block_mask_right = block_mask_right - F.pad(
1772
+ block_mask_right[:, :, 1:], (0, 1), value=0
1773
+ )
1774
+ block_mask_right.triu_(0)
1775
+
1776
+ block_p = block_mask_left[:, :, :, None] * block_mask_right[:, :, None, :]
1777
+ block = cumsum(block_mask_left).tril(0) + cumsum(
1778
+ block_mask_right, reverse=True
1779
+ ).triu(1)
1780
+
1781
+ return block_p, block
1782
+
1783
+ def compute_head(self, height, n_cntxt_layers=0):
1784
+ """Estimate head for each constituent."""
1785
+
1786
+ _, length = height.size()
1787
+ if n_cntxt_layers > 0:
1788
+ if n_cntxt_layers == 1:
1789
+ head_logits = height * self.scaler_1[1]
1790
+ elif n_cntxt_layers == 2:
1791
+ head_logits = height * self.scaler_2[1]
1792
+ else:
1793
+ head_logits = height * self.scaler[1]
1794
+ index = torch.arange(length, device=height.device)
1795
+
1796
+ mask = (index[:, None, None] <= index[None, None, :]) * (
1797
+ index[None, None, :] <= index[None, :, None]
1798
+ )
1799
+ head_logits = head_logits[:, None, None, :].repeat(1, length, length, 1)
1800
+ head_logits.masked_fill_(~mask[None, :, :, :], -1e9)
1801
+
1802
+ head_p = torch.softmax(head_logits, dim=-1)
1803
+
1804
+ return head_p
1805
+
1806
+ def parse(self, x, embs=None, n_cntxt_layers=0):
1807
+ """Parse input sentence.
1808
+
1809
+ Args:
1810
+ x: input tokens (required).
1811
+ embs: precomputed embeddings for x (optional); if None, the RoBERTa embeddings of x are used.
1812
+ Returns:
1813
+ distance: syntactic distance
1814
+ height: syntactic height
1815
+ """
1816
+
1817
+ mask = x != self.pad
1818
+ mask_shifted = F.pad(mask[:, 1:], (0, 1), value=0)
1819
+
1820
+ if embs is None:
1821
+ h = self.roberta.embeddings(x)
1822
+ else:
1823
+ h = embs
1824
+
1825
+ if n_cntxt_layers > 0:
1826
+ if n_cntxt_layers == 1:
1827
+ parser_layers = self.parser_layers_1
1828
+ height_ff = self.height_ff_1
1829
+ distance_ff = self.distance_ff_1
1830
+ elif n_cntxt_layers == 2:
1831
+ parser_layers = self.parser_layers_2
1832
+ height_ff = self.height_ff_2
1833
+ distance_ff = self.distance_ff_2
1834
+ for i in range(int(self.config.n_parser_layers/2)):
1835
+ h = h.masked_fill(~mask[:, :, None], 0)
1836
+ h = parser_layers[i](h)
1837
+
1838
+ height = height_ff(h).squeeze(-1)
1839
+ height.masked_fill_(~mask, -1e9)
1840
+
1841
+ distance = distance_ff(h).squeeze(-1)
1842
+ distance.masked_fill_(~mask_shifted, 1e9)
1843
+
1844
+ # Calibrating the distance and height to the same level
1845
+ length = distance.size(1)
1846
+ height_max = height[:, None, :].expand(-1, length, -1)
1847
+ height_max = torch.cummax(
1848
+ height_max.triu(0) - torch.ones_like(height_max).tril(-1) * 1e9, dim=-1
1849
+ )[0].triu(0)
1850
+
1851
+ margin_left = torch.relu(
1852
+ F.pad(distance[:, :-1, None], (0, 0, 1, 0), value=1e9) - height_max
1853
+ )
1854
+ margin_right = torch.relu(distance[:, None, :] - height_max)
1855
+ margin = torch.where(
1856
+ margin_left > margin_right, margin_right, margin_left
1857
+ ).triu(0)
1858
+
1859
+ margin_mask = torch.stack([mask_shifted] + [mask] * (length - 1), dim=1)
1860
+ margin.masked_fill_(~margin_mask, 0)
1861
+ margin = margin.max()
1862
+
1863
+ distance = distance - margin
1864
+ else:
1865
+ for i in range(self.config.n_parser_layers):
1866
+ h = h.masked_fill(~mask[:, :, None], 0)
1867
+ h = self.parser_layers[i](h)
1868
+
1869
+ height = self.height_ff(h).squeeze(-1)
1870
+ height.masked_fill_(~mask, -1e9)
1871
+
1872
+ distance = self.distance_ff(h).squeeze(-1)
1873
+ distance.masked_fill_(~mask_shifted, 1e9)
1874
+
1875
+ # Calibrating the distance and height to the same level
1876
+ length = distance.size(1)
1877
+ height_max = height[:, None, :].expand(-1, length, -1)
1878
+ height_max = torch.cummax(
1879
+ height_max.triu(0) - torch.ones_like(height_max).tril(-1) * 1e9, dim=-1
1880
+ )[0].triu(0)
1881
+
1882
+ margin_left = torch.relu(
1883
+ F.pad(distance[:, :-1, None], (0, 0, 1, 0), value=1e9) - height_max
1884
+ )
1885
+ margin_right = torch.relu(distance[:, None, :] - height_max)
1886
+ margin = torch.where(
1887
+ margin_left > margin_right, margin_right, margin_left
1888
+ ).triu(0)
1889
+
1890
+ margin_mask = torch.stack([mask_shifted] + [mask] * (length - 1), dim=1)
1891
+ margin.masked_fill_(~margin_mask, 0)
1892
+ margin = margin.max()
1893
+
1894
+ distance = distance - margin
1895
+
1896
+ return distance, height
1897
+
1898
+ def generate_mask(self, x, distance, height, n_cntxt_layers=0):
1899
+ """Compute head and cibling distribution for each token."""
1900
+
1901
+ bsz, length = x.size()
1902
+
1903
+ eye = torch.eye(length, device=x.device, dtype=torch.bool)
1904
+ eye = eye[None, :, :].expand((bsz, -1, -1))
1905
+
1906
+ block_p, block = self.compute_block(distance, height, n_cntxt_layers=n_cntxt_layers)
1907
+ head_p = self.compute_head(height, n_cntxt_layers=n_cntxt_layers)
1908
+ head = torch.einsum("blij,bijh->blh", block_p, head_p)
1909
+ head = head.masked_fill(eye, 0)
1910
+ child = head.transpose(1, 2)
1911
+ cibling = torch.bmm(head, child).masked_fill(eye, 0)
1912
+
1913
+ rel_list = []
1914
+ if "head" in self.config.relations:
1915
+ rel_list.append(head)
1916
+ if "child" in self.config.relations:
1917
+ rel_list.append(child)
1918
+ if "cibling" in self.config.relations:
1919
+ rel_list.append(cibling)
1920
+
1921
+ rel = torch.stack(rel_list, dim=1)
1922
+
1923
+ if n_cntxt_layers > 0:
1924
+ if n_cntxt_layers == 1:
1925
+ rel_weight = self.rel_weight_1
1926
+ elif n_cntxt_layers == 2:
1927
+ rel_weight = self.rel_weight_2
1928
+ else:
1929
+ rel_weight = self.rel_weight
1930
+
1931
+ dep = torch.einsum("lhr,brij->lbhij", rel_weight, rel)
1932
+
1933
+ if n_cntxt_layers == 1:
1934
+ num_layers = self.cntxt_layers_2.config.num_hidden_layers
1935
+ else:
1936
+ num_layers = self.roberta.config.num_hidden_layers
1937
+
1938
+ att_mask = dep.reshape(
1939
+ num_layers,
1940
+ bsz,
1941
+ self.config.num_attention_heads,
1942
+ length,
1943
+ length,
1944
+ )
1945
+
1946
+ return att_mask, cibling, head, block
1947
+
1948
+ def forward(
1949
+ self,
1950
+ input_ids: Optional[torch.LongTensor] = None,
1951
+ attention_mask: Optional[torch.FloatTensor] = None,
1952
+ token_type_ids: Optional[torch.LongTensor] = None,
1953
+ position_ids: Optional[torch.LongTensor] = None,
1954
+ head_mask: Optional[torch.FloatTensor] = None,
1955
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1956
+ labels: Optional[torch.LongTensor] = None,
1957
+ output_attentions: Optional[bool] = None,
1958
+ output_hidden_states: Optional[bool] = None,
1959
+ return_dict: Optional[bool] = None,
1960
+ ) -> Union[Tuple, SequenceClassifierOutput]:
1961
+ r"""
1962
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1963
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
1964
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1965
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1966
+ """
1967
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1968
+
1969
+ if self.config.n_cntxt_layers > 0:
1970
+ cntxt_outputs = self.cntxt_layers(
1971
+ input_ids,
1972
+ attention_mask=attention_mask,
1973
+ token_type_ids=token_type_ids,
1974
+ position_ids=position_ids,
1975
+ head_mask=head_mask,
1976
+ inputs_embeds=inputs_embeds,
1977
+ output_attentions=output_attentions,
1978
+ output_hidden_states=output_hidden_states,
1979
+ return_dict=return_dict)
1980
+
1981
+
1982
+ if self.config.n_cntxt_layers_2 > 0:
1983
+ distance_1, height_1 = self.parse(input_ids, cntxt_outputs[0], n_cntxt_layers=1)
1984
+ att_mask_1, _, _, _ = self.generate_mask(input_ids, distance_1, height_1, n_cntxt_layers=1)
1985
+
1986
+ cntxt_outputs_2 = self.cntxt_layers_2(
1987
+ input_ids,
1988
+ attention_mask=attention_mask,
1989
+ token_type_ids=token_type_ids,
1990
+ position_ids=position_ids,
1991
+ head_mask=head_mask,
1992
+ inputs_embeds=inputs_embeds,
1993
+ output_attentions=output_attentions,
1994
+ output_hidden_states=output_hidden_states,
1995
+ return_dict=return_dict,
1996
+ parser_att_mask=att_mask_1)
1997
+
1998
+ sequence_output = cntxt_outputs_2[0]
1999
+
2000
+ distance_2, height_2 = self.parse(input_ids, sequence_output, n_cntxt_layers=2)
2001
+ att_mask, _, _, _ = self.generate_mask(input_ids, distance_2, height_2, n_cntxt_layers=2)
2002
+
2003
+ elif self.config.n_cntxt_layers > 0:
2004
+ distance, height = self.parse(input_ids, cntxt_outputs[0])
2005
+ att_mask, _, _, _ = self.generate_mask(input_ids, distance, height)
2006
+ else:
2007
+ distance, height = self.parse(input_ids)
2008
+ att_mask, _, _, _ = self.generate_mask(input_ids, distance, height)
2009
+
2010
+ outputs = self.roberta(
2011
+ input_ids,
2012
+ attention_mask=attention_mask,
2013
+ token_type_ids=token_type_ids,
2014
+ position_ids=position_ids,
2015
+ head_mask=head_mask,
2016
+ inputs_embeds=inputs_embeds,
2017
+ output_attentions=output_attentions,
2018
+ output_hidden_states=output_hidden_states,
2019
+ return_dict=return_dict,
2020
+ parser_att_mask=att_mask,
2021
+ )
2022
+ sequence_output = outputs[0]
2023
+ logits = self.classifier(sequence_output)
2024
+
2025
+ loss = None
2026
+ if labels is not None:
2027
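+ # Standard HF recipe: infer the problem type (regression vs. single-/multi-label classification)
+ # from num_labels and the label dtype, then pick the matching loss.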
+ if self.config.problem_type is None:
2028
+ if self.num_labels == 1:
2029
+ self.config.problem_type = "regression"
2030
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
2031
+ self.config.problem_type = "single_label_classification"
2032
+ else:
2033
+ self.config.problem_type = "multi_label_classification"
2034
+
2035
+ if self.config.problem_type == "regression":
2036
+ loss_fct = MSELoss()
2037
+ if self.num_labels == 1:
2038
+ loss = loss_fct(logits.squeeze(), labels.squeeze())
2039
+ else:
2040
+ loss = loss_fct(logits, labels)
2041
+ elif self.config.problem_type == "single_label_classification":
2042
+ loss_fct = CrossEntropyLoss()
2043
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
2044
+ elif self.config.problem_type == "multi_label_classification":
2045
+ loss_fct = BCEWithLogitsLoss()
2046
+ loss = loss_fct(logits, labels)
2047
+
2048
+ if not return_dict:
2049
+ output = (logits,) + outputs[2:]
2050
+ return ((loss,) + output) if loss is not None else output
2051
+
2052
+ return SequenceClassifierOutput(
2053
+ loss=loss,
2054
+ logits=logits,
2055
+ hidden_states=outputs.hidden_states,
2056
+ attentions=outputs.attentions,
2057
+ )
2058
+
2059
+
2060
+ class RobertaClassificationHead(nn.Module):
2061
+ """Head for sentence-level classification tasks."""
2062
+
2063
+ def __init__(self, config):
2064
+ super().__init__()
2065
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
2066
+ classifier_dropout = (
2067
+ config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
2068
+ )
2069
+ self.dropout = nn.Dropout(classifier_dropout)
2070
+ self.out_proj = nn.Linear(config.hidden_size, config.num_labels)
2071
+
2072
+ def forward(self, features, **kwargs):
2073
+ x = features[:, 0, :] # take <s> token (equiv. to [CLS])
2074
+ x = self.dropout(x)
2075
+ x = self.dense(x)
2076
+ x = torch.tanh(x)
2077
+ x = self.dropout(x)
2078
+ x = self.out_proj(x)
2079
+ return x
2080
+
2081
+
2082
+ def create_position_ids_from_input_ids(
2083
+ input_ids, padding_idx, past_key_values_length=0
2084
+ ):
2085
+ """
2086
+ Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols
2087
+ are ignored. This is modified from fairseq's `utils.make_positions`.
2088
+
2089
+ Args:
2090
+ input_ids: torch.Tensor
2091
+
2092
+ Returns: torch.Tensor
2093
+ """
2094
+ # The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA.
2095
+ mask = input_ids.ne(padding_idx).int()
2096
+ incremental_indices = (
2097
+ torch.cumsum(mask, dim=1).type_as(mask) + past_key_values_length
2098
+ ) * mask
2099
+ return incremental_indices.long() + padding_idx
2100
+
2101
+
2102
+ def cumprod(x, reverse=False, exclusive=False):
2103
+ """cumulative product."""
2104
+ if reverse:
2105
+ x = x.flip([-1])
2106
+
2107
+ if exclusive:
2108
+ x = F.pad(x[:, :, :-1], (1, 0), value=1)
2109
+
2110
+ cx = x.cumprod(-1)
2111
+
2112
+ if reverse:
2113
+ cx = cx.flip([-1])
2114
+ return cx
2115
+
2116
+
2117
+ def cumsum(x, reverse=False, exclusive=False):
2118
+ """cumulative sum."""
2119
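+ # Implemented as a batched matmul against a triangular matrix of ones; the reverse/exclusive flags
+ # only change which triangle (and whether the diagonal) is kept.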
+ bsz, _, length = x.size()
2120
+ device = x.device
2121
+ if reverse:
2122
+ if exclusive:
2123
+ w = torch.ones([bsz, length, length], device=device).tril(-1)
2124
+ else:
2125
+ w = torch.ones([bsz, length, length], device=device).tril(0)
2126
+ cx = torch.bmm(x, w)
2127
+ else:
2128
+ if exclusive:
2129
+ w = torch.ones([bsz, length, length], device=device).triu(1)
2130
+ else:
2131
+ w = torch.ones([bsz, length, length], device=device).triu(0)
2132
+ cx = torch.bmm(x, w)
2133
+ return cx
2134
+
2135
+
2136
+ def cummin(x, reverse=False, exclusive=False, max_value=1e9):
2137
+ """cumulative min."""
2138
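+ # exclusive=True shifts the input by one position (padding with max_value) so each position only
+ # aggregates strictly earlier entries (or strictly later ones when reverse=True).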
+ if reverse:
2139
+ if exclusive:
2140
+ x = F.pad(x[:, :, 1:], (0, 1), value=max_value)
2141
+ x = x.flip([-1]).cummin(-1)[0].flip([-1])
2142
+ else:
2143
+ if exclusive:
2144
+ x = F.pad(x[:, :, :-1], (1, 0), value=max_value)
2145
+ x = x.cummin(-1)[0]
2146
+ return x
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:818c667dfd50dc5dd559e30d4f79a648b8e2334ccdcc2fc2cffbea6b50a10ff4
3
+ size 577194687
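For reference, a minimal loading sketch for the checkpoint added in this commit, assuming the usual Hugging Face custom-code workflow (`trust_remote_code=True` resolves the custom classes in the modeling file above; the `StructRobertaForSequenceClassification` class can be loaded analogously through `AutoModelForSequenceClassification` if it is registered for auto loading). The repo id, the example sentence, and the presence of tokenizer files are illustrative assumptions, not part of the commit.

import torch
from transformers import AutoModelForMaskedLM, AutoTokenizer

repo_id = "omarmomen/structroberta"  # hypothetical repo id, for illustration only
tokenizer = AutoTokenizer.from_pretrained(repo_id)  # assumes tokenizer files exist in the repo
model = AutoModelForMaskedLM.from_pretrained(repo_id, trust_remote_code=True)
model.eval()

# "<mask>" assumes a RoBERTa-style tokenizer.
inputs = tokenizer("The parser decides which <mask> each head attends to.", return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

# The forward pass first parses the input into soft constituents, converts them into
# per-layer attention masks, and only then runs the masked-LM encoder.
print(outputs.logits.shape)  # (batch_size, sequence_length, vocab_size)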