File size: 1,533 Bytes
1f878a8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 |
import torch
import torch.nn as nn
from transformers import (
RobertaModel,
RobertaForQuestionAnswering,
)
class SpanPredictionHead(nn.Module):
    """Head for span prediction tasks.

    Can be viewed as a 2-class output layer (start logit, end logit) that is
    applied independently to every sequence position:
    dropout -> dense -> tanh -> dropout -> out_proj.

    Args:
        input_dim: hidden size of the incoming per-token features.
        inner_dim: size of the intermediate projection.
        num_classes: number of output logits per position; must be 2.
        pooler_dropout: dropout probability applied before each linear layer.

    Raises:
        ValueError: if ``num_classes`` is not 2.
    """

    def __init__(self, input_dim, inner_dim, num_classes, pooler_dropout):
        super().__init__()
        # Raise instead of `assert`: asserts are stripped under `python -O`,
        # which would silently allow an invalid head configuration.
        if num_classes != 2:
            raise ValueError(f"num_classes must be 2, got {num_classes}")
        self.dense = nn.Linear(input_dim, inner_dim)
        self.dropout = nn.Dropout(p=pooler_dropout)
        self.out_proj = nn.Linear(inner_dim, num_classes)

    def forward(self, features, **kwargs):
        """Project features at ALL positions to span logits.

        Returns a B x T x C tensor (C == 2). Note the softmax for span
        selection should be taken over T (the sequence dim), not over C.
        """
        x = features  # take features across ALL positions
        x = self.dropout(x)
        x = self.dense(x)
        x = torch.tanh(x)
        x = self.dropout(x)
        x = self.out_proj(x)
        return x  # B x T x C, but softmax should be taken over T
class RobertaForPororoMRC(RobertaForQuestionAnswering):
    """RoBERTa question-answering model with a ``SpanPredictionHead``.

    Mirrors ``RobertaForQuestionAnswering`` but replaces its span-output
    layer (``qa_outputs``) with the two-layer ``SpanPredictionHead``
    defined above. Inherits the parent's forward pass unchanged.
    """

    def __init__(self, config):
        # Deliberately skip RobertaForQuestionAnswering.__init__ and
        # initialize on RobertaPreTrainedModel instead, so the parent's own
        # `roberta`/`qa_outputs` submodules are never built (we construct
        # our own versions below).
        super(RobertaForQuestionAnswering, self).__init__(config)
        # Span extraction always has exactly two labels (start, end).
        # NOTE(review): this mutates the caller's `config` object in place.
        config.num_labels = 2
        self.num_labels = config.num_labels
        # No pooling layer: span prediction uses per-token hidden states,
        # not a pooled [CLS] representation.
        self.roberta = RobertaModel(config, add_pooling_layer=False)
        self.qa_outputs = SpanPredictionHead(
            input_dim=config.hidden_size,
            # presumably custom config fields added for this model — verify
            # `span_head_inner_dim`/`span_head_dropout` exist on the config.
            inner_dim=config.span_head_inner_dim,
            num_classes=config.num_labels,
            pooler_dropout=config.span_head_dropout,
        )
        self.init_weights()