File size: 614 Bytes
6a34fd4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
import torch.nn as nn

from bert import BERT


class BERTForClassification(nn.Module):
    """Sequence classifier built from a BERT encoder plus one linear head.

    The head reads the encoder output at position 0 of each sequence and
    projects it to ``n_labels`` raw scores. No softmax or log-softmax is
    applied here; callers receive unnormalized logits.
    """

    def __init__(self, bert: BERT, vocab_size, n_labels):
        """Attach a linear classification head to ``bert``.

        :param bert: encoder to fine-tune; must expose a ``hidden`` attribute
            equal to the size of the last dimension of its output.
        :param vocab_size: accepted for call-site compatibility only; this
            module does not use it.
        :param n_labels: number of classes produced by the linear head.
        """
        super().__init__()
        # The attribute names ``bert`` and ``linear`` are part of the
        # module's state_dict keys, so they must not be renamed.
        self.bert = bert
        feature_dim = bert.hidden
        self.linear = nn.Linear(feature_dim, n_labels)

    def forward(self, x, segment_label):
        """Encode the input and classify it from its first position.

        :param x: token input, forwarded unchanged to ``self.bert``.
        :param segment_label: segment input, forwarded unchanged to
            ``self.bert``.
        :return: tuple ``(encoded, logits)``. ``encoded`` is the raw encoder
            output. ``logits`` is the linear head applied to ``encoded[:, 0]``.
        """
        encoded = self.bert(x, segment_label)
        # NOTE(review): assumes encoded is (batch, seq, hidden) and that
        # position 0 holds a summary token such as [CLS] — TODO confirm.
        first_position = encoded[:, 0]
        logits = self.linear(first_position)
        return encoded, logits