from transformers import AutoModel, PreTrainedModel, BertConfig
import torch


class MultiLabelAttention(torch.nn.Module):
    def __init__(self, D_in, num_labels):
        super().__init__()
        # One learnable attention vector per label, shape (D_in, num_labels)
        self.A = torch.nn.Parameter(torch.empty(D_in, num_labels))
        torch.nn.init.uniform_(self.A, -0.1, 0.1)

    def forward(self, x):
        # x: (batch_size, seq_len, D_in) token embeddings
        attention_weights = torch.nn.functional.softmax(
            torch.tanh(torch.matmul(x, self.A)), dim=1
        )
        # One attention-pooled vector per label: (batch_size, num_labels, D_in)
        return torch.matmul(torch.transpose(attention_weights, 2, 1), x)
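
# Shape sketch (illustrative, not part of the original file): with token
# embeddings x of shape (batch_size, seq_len, 768) and num_labels = 4, the
# layer pools the sequence separately for each label:
#
#     attn = MultiLabelAttention(D_in=768, num_labels=4)
#     x = torch.randn(2, 128, 768)
#     attn(x).shape  # torch.Size([2, 4, 768])
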

class BertMesh(PreTrainedModel):
    config_class = BertConfig

    def __init__(
        self,
        config,
    ):
        super().__init__(config=config)
        self.config.auto_map = {"AutoModel": "model.BertMesh"}
        self.pretrained_model = self.config.pretrained_model
        self.num_labels = self.config.num_labels
        self.hidden_size = getattr(self.config, "hidden_size", 512)
        self.dropout = getattr(self.config, "dropout", 0.1)
        self.multilabel_attention = getattr(self.config, "multilabel_attention", False)

        # Transformer encoder producing 768-dim token embeddings
        self.bert = AutoModel.from_pretrained(self.pretrained_model)
        self.multilabel_attention_layer = MultiLabelAttention(
            768, self.num_labels
        )  # -> (batch_size, num_labels, 768)
        self.linear_1 = torch.nn.Linear(768, self.hidden_size)  # -> (..., hidden_size)
        self.linear_2 = torch.nn.Linear(self.hidden_size, 1)  # per-label score head
        self.linear_out = torch.nn.Linear(self.hidden_size, self.num_labels)  # [CLS] head
        self.dropout_layer = torch.nn.Dropout(self.dropout)

    def forward(self, input_ids, **kwargs):
        if isinstance(input_ids, list):
            # coming from the tokenizer without return_tensors set
            input_ids = torch.tensor(input_ids)
        if self.multilabel_attention:
            # Per-label attention pooling over the token embeddings
            hidden_states = self.bert(input_ids=input_ids)[0]
            attention_outs = self.multilabel_attention_layer(hidden_states)
            outs = torch.nn.functional.relu(self.linear_1(attention_outs))
            outs = self.dropout_layer(outs)
            outs = torch.sigmoid(self.linear_2(outs))
            outs = torch.flatten(outs, start_dim=1)
        else:
            # Single pooled [CLS] representation scored against all labels
            cls = self.bert(input_ids=input_ids)[1]
            outs = torch.nn.functional.relu(self.linear_1(cls))
            outs = self.dropout_layer(outs)
            outs = torch.sigmoid(self.linear_out(outs))
        return outs
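

# Usage sketch (illustrative, not part of the original file). Assumes the config
# carries the custom fields this model reads (pretrained_model, num_labels,
# multilabel_attention, hidden_size, dropout) and that "bert-base-uncased"
# weights can be downloaded; the label count and input text are made up.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    config = BertConfig(
        pretrained_model="bert-base-uncased",
        num_labels=4,
        multilabel_attention=True,
        hidden_size=512,
        dropout=0.1,
    )
    model = BertMesh(config)
    model.eval()

    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    inputs = tokenizer(["an abstract about malaria transmission"], padding=True)
    with torch.no_grad():
        probs = model(inputs["input_ids"])
    print(probs.shape)  # torch.Size([1, 4]) -- one sigmoid score per label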