Transformers
PyTorch
code
custom_code
Inference Endpoints
codesage commited on
Commit
34e872f
1 Parent(s): 006daa1

Create modeling_codesage.py

Browse files
Files changed (1) hide show
  1. modeling_codesage.py +358 -0
modeling_codesage.py ADDED
@@ -0,0 +1,358 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
4
+
5
+ import math
6
+ import torch
7
+ import torch.utils.checkpoint
8
+ from torch import nn
9
+ from torch.nn import CrossEntropyLoss, MSELoss, BCEWithLogitsLoss
10
+ from transformers.activations import ACT2FN
11
+ from transformers.modeling_utils import Conv1D, PreTrainedModel
12
+ from transformers.utils import logging
13
+ from .config_codesage import CodeSageConfig
14
+ from transformers.modeling_outputs import BaseModelOutputWithPooling, SequenceClassifierOutput
15
+
16
+ logger = logging.get_logger(__name__)
17
+
18
+ CODESAGE_PRETRAINED_MODEL_ARCHIVE_LIST = [
19
+ "codesage/codesage-small",
20
+ "codesage/codesage-base",
21
+ "codesage/codesage-large",
22
+ # See all CodeSage models at https://huggingface.co/models?filter=codesage
23
+ ]
24
+
25
+
26
+ class CodeSageAttention(nn.Module):
27
+ def __init__(self, config):
28
+ super().__init__()
29
+
30
+ self.hidden_size = config.hidden_size
31
+ self.num_heads = config.num_attention_heads
32
+ self.head_dim = config.hidden_size // self.num_heads
33
+ if self.head_dim * self.num_heads != config.hidden_size:
34
+ raise ValueError(
35
+ f"`hidden_size` must be divisible by num_heads "
36
+ f"(got `hidden_size`: {config.hidden_size} and `num_heads`: {self.num_heads})."
37
+ )
38
+
39
+ self.c_attn = Conv1D(3 * self.hidden_size, self.hidden_size)
40
+ self.c_proj = Conv1D(self.hidden_size, self.hidden_size)
41
+
42
+ self.attention_dropout = nn.Dropout(config.attention_dropout_prob)
43
+ self.residual_dropout = nn.Dropout(config.residual_dropout_prob)
44
+
45
+ def attn(self, query, key, value, attention_mask=None, head_mask=None):
46
+ attn_weights = torch.matmul(query, key.transpose(-1, -2))
47
+ attn_weights = attn_weights / math.sqrt(self.head_dim)
48
+ if attention_mask is not None:
49
+ attn_weights = attn_weights + attention_mask
50
+
51
+ attn_weights = nn.Softmax(dim=-1)(attn_weights)
52
+ attn_weights = self.attention_dropout(attn_weights)
53
+ if head_mask is not None:
54
+ attn_weights = attn_weights * head_mask
55
+
56
+ attn_output = torch.matmul(attn_weights, value)
57
+ return attn_output, attn_weights
58
+
59
+ def split_heads(self, tensor, num_heads, attn_head_size):
60
+ """
61
+ Splits hidden_size dim into attn_head_size and num_heads
62
+ """
63
+ new_shape = tensor.size()[:-1] + (num_heads, attn_head_size)
64
+ tensor = tensor.view(*new_shape)
65
+ return tensor.permute(0, 2, 1, 3) # (batch, head, seq_length, head_features)
66
+
67
+ def merge_heads(self, tensor, num_heads, attn_head_size):
68
+ """
69
+ Merges attn_head_size dim and num_attn_heads dim into hidden_size
70
+ """
71
+ tensor = tensor.permute(0, 2, 1, 3).contiguous()
72
+ new_shape = tensor.size()[:-2] + (num_heads * attn_head_size,)
73
+ return tensor.view(new_shape)
74
+
75
+ def forward(
76
+ self,
77
+ hidden_states,
78
+ attention_mask=None,
79
+ head_mask=None,
80
+ output_attentions=False,
81
+ ):
82
+ query, key, value = self.c_attn(hidden_states).split(self.hidden_size, dim=2)
83
+ query = self.split_heads(query, self.num_heads, self.head_dim)
84
+ key = self.split_heads(key, self.num_heads, self.head_dim)
85
+ value = self.split_heads(value, self.num_heads, self.head_dim)
86
+
87
+ attn_output, attn_weights = self.attn(query, key, value, attention_mask, head_mask)
88
+
89
+ attn_output = self.merge_heads(attn_output, self.num_heads, self.head_dim)
90
+ attn_output = self.c_proj(attn_output)
91
+ attn_output = self.residual_dropout(attn_output)
92
+
93
+ outputs = (attn_output, attn_weights) if output_attentions else (attn_output,)
94
+ return outputs # a, present, (attentions)
95
+
96
+
97
+ class CodeSageMLP(nn.Module):
98
+ def __init__(self, intermediate_size, config):
99
+ super().__init__()
100
+
101
+ self.c_fc = Conv1D(intermediate_size, config.hidden_size)
102
+ self.act = ACT2FN[config.activation_function]
103
+ self.c_proj = Conv1D(config.hidden_size, intermediate_size)
104
+ self.dropout = nn.Dropout(config.residual_dropout_prob)
105
+
106
+ def forward(self, hidden_states):
107
+ hidden_states = self.c_fc(hidden_states)
108
+ hidden_states = self.act(hidden_states)
109
+ hidden_states = self.c_proj(hidden_states)
110
+ hidden_states = self.dropout(hidden_states)
111
+ return hidden_states
112
+
113
+
114
+ class CodeSageBlock(nn.Module):
115
+ def __init__(self, config):
116
+ super().__init__()
117
+ hidden_size = config.hidden_size
118
+ inner_dim = config.intermediate_size if config.intermediate_size is not None else 4 * hidden_size
119
+ self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
120
+ self.attn = CodeSageAttention(config)
121
+ self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
122
+ self.mlp = CodeSageMLP(inner_dim, config)
123
+
124
+ def forward(
125
+ self,
126
+ hidden_states,
127
+ attention_mask=None,
128
+ head_mask=None,
129
+ output_attentions=False,
130
+ ):
131
+ residual = hidden_states
132
+ hidden_states = self.ln_1(hidden_states)
133
+ attn_outputs = self.attn(
134
+ hidden_states,
135
+ attention_mask=attention_mask,
136
+ head_mask=head_mask,
137
+ output_attentions=output_attentions
138
+ )
139
+ attn_output = attn_outputs[0] # output_attn: a, present, (attentions)
140
+ outputs = attn_outputs[1:]
141
+ hidden_states = attn_output + residual
142
+
143
+ residual = hidden_states
144
+ hidden_states = self.ln_2(hidden_states)
145
+ feed_forward_hidden_states = self.mlp(hidden_states)
146
+ hidden_states = residual + feed_forward_hidden_states
147
+
148
+ outputs = (hidden_states,) + outputs[1:]
149
+ return outputs # hidden_states, present, (attentions)
150
+
151
+
152
+ class CodeSagePreTrainedModel(PreTrainedModel):
153
+ config_class = CodeSageConfig
154
+
155
+ def _init_weights(self, module):
156
+ """Initialize the weights."""
157
+ if isinstance(module, (nn.Linear, Conv1D)):
158
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
159
+ if module.bias is not None:
160
+ module.bias.data.zero_()
161
+ elif isinstance(module, nn.Embedding):
162
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
163
+ if module.padding_idx is not None:
164
+ module.weight.data[module.padding_idx].zero_()
165
+ elif isinstance(module, nn.LayerNorm):
166
+ module.bias.data.zero_()
167
+ module.weight.data.fill_(1.0)
168
+
169
+
170
+ class CodeSageModel(CodeSagePreTrainedModel):
171
+ def __init__(self, config):
172
+ super().__init__(config)
173
+
174
+ self.wte = nn.Embedding(config.vocab_size, config.hidden_size)
175
+ self.wpe = nn.Embedding(config.max_position_embeddings, config.hidden_size)
176
+
177
+ self.drop = nn.Dropout(config.embedding_dropout_prob)
178
+ self.h = nn.ModuleList([CodeSageBlock(config) for _ in range(config.num_hidden_layers)])
179
+ self.ln_f = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon)
180
+
181
+ self.init_weights()
182
+
183
+ def get_input_embeddings(self):
184
+ return self.wte
185
+
186
+ def set_input_embeddings(self, new_embeddings: torch.Tensor):
187
+ self.wte = new_embeddings
188
+
189
+ def forward(
190
+ self,
191
+ input_ids=None,
192
+ attention_mask=None,
193
+ position_ids=None,
194
+ head_mask=None,
195
+ inputs_embeds=None,
196
+ output_attentions=None,
197
+ output_hidden_states=None,
198
+ return_dict=None
199
+ ):
200
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
201
+ output_hidden_states = (
202
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
203
+ )
204
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
205
+
206
+ if input_ids is not None and inputs_embeds is not None:
207
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
208
+ if input_ids is not None:
209
+ input_shape = input_ids.size()
210
+ elif inputs_embeds is not None:
211
+ input_shape = inputs_embeds.size()[:-1]
212
+ else:
213
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
214
+
215
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
216
+ if position_ids is None:
217
+ position_ids = torch.arange(input_shape[-1], dtype=torch.long, device=device)
218
+ position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
219
+ else:
220
+ position_ids = position_ids.view(-1, input_shape[-1])
221
+
222
+ extended_attention_mask = None
223
+ if attention_mask is not None:
224
+ assert attention_mask.dim() == 2
225
+ extended_attention_mask = attention_mask[:, None, None, :]
226
+ extended_attention_mask = extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility
227
+ extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
228
+
229
+ head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
230
+ if inputs_embeds is None:
231
+ inputs_embeds = self.wte(input_ids)
232
+
233
+ position_embeds = self.wpe(position_ids)
234
+ hidden_states = inputs_embeds + position_embeds
235
+
236
+ hidden_states = self.drop(hidden_states)
237
+ output_shape = input_shape + (hidden_states.size(-1),)
238
+
239
+ all_self_attentions = () if output_attentions else None
240
+ all_hidden_states = () if output_hidden_states else None
241
+ for i, block in enumerate(self.h):
242
+ if output_hidden_states:
243
+ all_hidden_states = all_hidden_states + (hidden_states,)
244
+
245
+ outputs = block(
246
+ hidden_states,
247
+ attention_mask=extended_attention_mask,
248
+ head_mask=head_mask[i],
249
+ output_attentions=output_attentions,
250
+ )
251
+
252
+ hidden_states = outputs[0]
253
+ if output_attentions:
254
+ all_self_attentions = all_self_attentions + (outputs[1],)
255
+
256
+ hidden_states = self.ln_f(hidden_states)
257
+ hidden_states = hidden_states.view(*output_shape)
258
+ if output_hidden_states:
259
+ all_hidden_states = all_hidden_states + (hidden_states,)
260
+
261
+ pooled_output = None # max-pooled output
262
+ if attention_mask is not None:
263
+ pooled_output = (hidden_states * attention_mask[:, :, None]).sum(1) / attention_mask.sum(1)[:, None]
264
+
265
+ if not return_dict:
266
+ return tuple(
267
+ v
268
+ for v in [hidden_states, pooled_output, all_hidden_states, all_self_attentions]
269
+ if v is not None
270
+ )
271
+
272
+ return BaseModelOutputWithPooling(
273
+ last_hidden_state=hidden_states,
274
+ pooler_output=pooled_output,
275
+ hidden_states=all_hidden_states,
276
+ attentions=all_self_attentions
277
+ )
278
+
279
+
280
+ class CodeSageForSequenceClassification(CodeSagePreTrainedModel):
281
+ def __init__(self, config):
282
+ super().__init__(config)
283
+ self.num_labels = config.num_labels
284
+ self.config = config
285
+
286
+ self.encoder = CodeSageModel(config)
287
+ classifier_dropout = (
288
+ config.classifier_dropout if config.classifier_dropout is not None else config.residual_dropout_prob
289
+ )
290
+ self.dropout = nn.Dropout(classifier_dropout)
291
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels)
292
+
293
+ # Initialize weights and apply final processing
294
+ self.post_init()
295
+
296
+ def forward(
297
+ self,
298
+ input_ids=None,
299
+ attention_mask=None,
300
+ position_ids=None,
301
+ head_mask=None,
302
+ inputs_embeds=None,
303
+ labels=None,
304
+ output_attentions=None,
305
+ output_hidden_states=None,
306
+ return_dict=None,
307
+ ):
308
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
309
+ assert attention_mask is not None, "attention_mask is needed to perform max-pooling"
310
+
311
+ outputs = self.encoder(
312
+ input_ids,
313
+ attention_mask=attention_mask,
314
+ position_ids=position_ids,
315
+ head_mask=head_mask,
316
+ inputs_embeds=inputs_embeds,
317
+ output_attentions=output_attentions,
318
+ output_hidden_states=output_hidden_states,
319
+ return_dict=return_dict,
320
+ )
321
+
322
+ pooled_output = outputs[1]
323
+ pooled_output = self.dropout(pooled_output)
324
+ logits = self.classifier(pooled_output)
325
+
326
+ loss = None
327
+ if labels is not None:
328
+ if self.config.problem_type is None:
329
+ if self.num_labels == 1:
330
+ self.config.problem_type = "regression"
331
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
332
+ self.config.problem_type = "single_label_classification"
333
+ else:
334
+ self.config.problem_type = "multi_label_classification"
335
+
336
+ if self.config.problem_type == "regression":
337
+ loss_fct = MSELoss()
338
+ if self.num_labels == 1:
339
+ loss = loss_fct(logits.squeeze(), labels.squeeze())
340
+ else:
341
+ loss = loss_fct(logits, labels)
342
+ elif self.config.problem_type == "single_label_classification":
343
+ loss_fct = CrossEntropyLoss()
344
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
345
+ elif self.config.problem_type == "multi_label_classification":
346
+ loss_fct = BCEWithLogitsLoss()
347
+ loss = loss_fct(logits, labels)
348
+
349
+ if not return_dict:
350
+ output = (logits,) + outputs[2:]
351
+ return ((loss,) + output) if loss is not None else output
352
+
353
+ return SequenceClassifierOutput(
354
+ loss=loss,
355
+ logits=logits,
356
+ hidden_states=outputs.hidden_states,
357
+ attentions=outputs.attentions,
358
+ )