File size: 1,101 Bytes
553d99f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
import torch
import torch.nn as nn
from transformers import T5EncoderModel
from transformers.modeling_outputs import SequenceClassifierOutput


class IntentPredictModel(nn.Module):
    """Intent classifier: a T5 encoder followed by a two-layer linear head.

    The encoder's per-token hidden states are mean-pooled over the
    non-padding positions and passed through ``fc1`` -> ``fc2`` to produce
    one logit per intent class.

    Args:
        pretrained_model_name_or_path: Model id or local path forwarded to
            ``T5EncoderModel.from_pretrained``.
        num_classes: Number of output logits (intent classes).
        hidden_size: Width of the intermediate ``fc1`` layer. Defaults to 128,
            the value that was previously hard-coded.
    """

    def __init__(self, pretrained_model_name_or_path, num_classes, hidden_size=128):
        super().__init__()
        self.backbone = T5EncoderModel.from_pretrained(pretrained_model_name_or_path)

        # Encoder output width (d_model): 768 for t5-base, 1024 for t5-large.
        # Read it from the config instead of reaching into the last block's
        # feed-forward projection, which depends on internal module layout.
        out_features = self.backbone.config.d_model
        self.fc1 = nn.Linear(out_features, hidden_size)
        self.fc2 = nn.Linear(hidden_size, num_classes)

    @staticmethod
    def _masked_mean(hidden_states, attention_mask):
        """Average ``hidden_states`` over the sequence axis, skipping padding.

        Args:
            hidden_states: ``[batch_size, seq_len, dim]`` tensor.
            attention_mask: ``[batch_size, seq_len]`` tensor with nonzero/True
                for real tokens and 0/False for padding, or ``None`` to
                average every position.

        Returns:
            ``[batch_size, dim]`` tensor. Rows whose mask is entirely zero
            yield zeros rather than NaN.
        """
        if attention_mask is None:
            return hidden_states.mean(dim=1)
        mask = attention_mask.unsqueeze(-1).to(hidden_states.dtype)  # [batch_size, seq_len, 1]
        summed = (hidden_states * mask).sum(dim=1)  # [batch_size, dim]
        # Clamp so a fully padded row divides by 1 instead of 0.
        counts = mask.sum(dim=1).clamp(min=1)  # [batch_size, 1]
        return summed / counts

    def forward(self, input_ids, attention_mask=None):
        """Compute intent logits for a batch of token sequences.

        Args:
            input_ids: Token ids, presumably ``[batch_size, seq_len]``.
            attention_mask: Optional ``[batch_size, seq_len]`` mask
                (1 = real token, 0 = padding). When ``None``, every position
                is attended to and pooled.

        Returns:
            ``SequenceClassifierOutput`` whose ``logits`` have shape
            ``[batch_size, num_classes]``. No loss is computed here.
        """
        out_backbone = self.backbone(input_ids=input_ids, attention_mask=attention_mask)
        hidden = out_backbone.last_hidden_state  # [batch_size, seq_len, out_features]
        # Pool only over real tokens: a plain mean over dim=1 would mix padding
        # positions into the sentence embedding, making predictions depend on
        # how much each batch was padded.
        pooled = self._masked_mean(hidden, attention_mask)  # [batch_size, out_features]
        # NOTE(review): there is no non-linearity between fc1 and fc2, so the
        # head collapses to a single affine map. Consider inserting an
        # activation (e.g. nn.ReLU) when retraining.
        out_fc1 = self.fc1(pooled)  # [batch_size, hidden_size]
        logits = self.fc2(out_fc1)  # [batch_size, num_classes]

        return SequenceClassifierOutput(logits=logits)