from transformers import AutoModel, PreTrainedModel, BertConfig
import torch


class MultiLabelAttention(torch.nn.Module):
    """Per-label attention: learns one attention vector per label over the token sequence."""

    def __init__(self, D_in, num_labels):
        super().__init__()
        self.A = torch.nn.Parameter(torch.empty(D_in, num_labels))
        torch.nn.init.uniform_(self.A, -0.1, 0.1)

    def forward(self, x):
        # x: (batch, seq_len, D_in) -> attention_weights: (batch, seq_len, num_labels)
        attention_weights = torch.nn.functional.softmax(
            torch.tanh(torch.matmul(x, self.A)), dim=1
        )
        # (batch, num_labels, D_in): one attention-pooled representation per label
        return torch.matmul(torch.transpose(attention_weights, 2, 1), x)


class BertMesh(PreTrainedModel):
    config_class = BertConfig

    def __init__(self, config):
        super().__init__(config=config)
        self.config.auto_map = {"AutoModel": "model.BertMesh"}
        self.pretrained_model = self.config.pretrained_model
        self.num_labels = self.config.num_labels
        self.hidden_size = getattr(self.config, "hidden_size", 512)
        self.dropout = getattr(self.config, "dropout", 0.1)
        self.multilabel_attention = getattr(self.config, "multilabel_attention", False)

        # Encoder is assumed to produce 768-dim hidden states (BERT-base)
        self.bert = AutoModel.from_pretrained(self.pretrained_model)
        # (batch, seq_len, 768) -> (batch, num_labels, 768)
        self.multilabel_attention_layer = MultiLabelAttention(768, self.num_labels)
        # 768 -> hidden_size
        self.linear_1 = torch.nn.Linear(768, self.hidden_size)
        # hidden_size -> 1 logit per label (multilabel-attention path)
        self.linear_2 = torch.nn.Linear(self.hidden_size, 1)
        # hidden_size -> num_labels (pooled [CLS] path)
        self.linear_out = torch.nn.Linear(self.hidden_size, self.num_labels)
        self.dropout_layer = torch.nn.Dropout(self.dropout)

    def forward(self, input_ids, **kwargs):
        if isinstance(input_ids, list):
            # coming straight from the tokenizer without return_tensors="pt"
            input_ids = torch.tensor(input_ids)
        if self.multilabel_attention:
            # per-label attention over the full sequence of token embeddings
            hidden_states = self.bert(input_ids=input_ids)[0]
            attention_outs = self.multilabel_attention_layer(hidden_states)
            outs = torch.nn.functional.relu(self.linear_1(attention_outs))
            outs = self.dropout_layer(outs)
            outs = torch.sigmoid(self.linear_2(outs))
            outs = torch.flatten(outs, start_dim=1)  # (batch, num_labels)
        else:
            # pooled [CLS] representation followed by a single classification head
            cls = self.bert(input_ids=input_ids)[1]
            outs = torch.nn.functional.relu(self.linear_1(cls))
            outs = self.dropout_layer(outs)
            outs = torch.sigmoid(self.linear_out(outs))
        return outs
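

if __name__ == "__main__":
    # Minimal usage sketch. The encoder checkpoint ("bert-base-uncased"), the
    # label count, and the example sentence are illustrative assumptions, not
    # values defined by this file; swap in the config/checkpoint you actually use.
    from transformers import AutoTokenizer

    config = BertConfig(
        pretrained_model="bert-base-uncased",
        num_labels=5,
        multilabel_attention=True,
    )
    model = BertMesh(config)
    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    inputs = tokenizer(["malaria vaccine trial in rural clinics"])
    probs = model(inputs["input_ids"])  # (batch, num_labels) sigmoid probabilities
    print(probs.shape)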