import torch.nn as nn

from bert import BERT


class BERTForClassification(nn.Module):
    """
    Progress Classifier Model: a sequence-level classification head on top
    of BERT that predicts one of ``n_labels`` classes from the hidden state
    of the first ([CLS]) token.
    """
    def __init__(self, bert: BERT, vocab_size, n_labels):
        """
        :param bert: BERT model which should be trained
        :param vocab_size: total vocab size (unused by this head; kept for
            interface compatibility with the masked-LM heads)
        :param n_labels: number of target classes for classification
        """
        super().__init__()
        self.bert = bert
        # Linear head mapping the BERT hidden size to per-label logits.
        self.linear = nn.Linear(self.bert.hidden, n_labels)
        # No explicit softmax: nn.CrossEntropyLoss expects raw logits.
        # self.softmax = nn.LogSoftmax(dim=-1)

    def forward(self, x, segment_label):
        # x: (batch, seq_len) token ids; segment_label: (batch, seq_len).
        x = self.bert(x, segment_label)  # (batch, seq_len, hidden)
        # Return the full hidden states plus logits computed from the
        # first ([CLS]) token's hidden state.
        return x, self.linear(x[:, 0])
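

if __name__ == "__main__":
    # Minimal usage sketch with hypothetical sizes. It assumes this
    # repository's BERT constructor takes (vocab_size, hidden, n_layers,
    # attn_heads); adjust the call to match the actual signature in bert.py.
    import torch

    vocab_size, n_labels = 30000, 4
    bert = BERT(vocab_size, hidden=256, n_layers=4, attn_heads=4)
    model = BERTForClassification(bert, vocab_size, n_labels)

    tokens = torch.randint(0, vocab_size, (8, 64))   # (batch, seq_len) ids
    segments = torch.ones(8, 64, dtype=torch.long)   # single-segment labels
    hidden_states, logits = model(tokens, segments)  # logits: (8, n_labels)

    # Raw logits feed directly into CrossEntropyLoss, which applies
    # log-softmax internally.
    targets = torch.randint(0, n_labels, (8,))
    loss = nn.CrossEntropyLoss()(logits, targets)
    print(hidden_states.shape, logits.shape, loss.item())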