File size: 1,509 Bytes
04f3f18
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
import torch
from .transformerutils import TransformerInterEncoder
from transformers import PreTrainedModel, AutoModel, BertConfig
from .configuration import ExtSummConfig



class BERTSummarizer(PreTrainedModel):
    """Extractive summarizer: BERT token encoder followed by a sentence scorer.

    Each document is encoded with a BERT backbone. The token states selected by
    ``clss_mask`` (presumably one per sentence, e.g. the [CLS] token preceding
    each sentence -- TODO confirm against the data pipeline) are passed to a
    ``TransformerInterEncoder``, which returns per-sentence scores and logits.
    Per-document outputs are right-padded with zeros to ``config.input_size``
    so they can be stacked into batch tensors.
    """

    config_class = ExtSummConfig

    # Hub identifier whose *configuration* defines the backbone architecture.
    # Only the config is fetched: AutoModel.from_config builds the model with
    # freshly initialised weights (presumably the real weights come from
    # loading this model's own checkpoint -- confirm). Override in a subclass
    # to use a different backbone.
    bert_config_name = "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext"

    def __init__(self, config):
        """Build the BERT backbone and the inter-sentence encoder.

        Args:
            config: an ``ExtSummConfig``; ``config.input_size`` is the fixed
                length every per-document output is padded to.
        """
        super().__init__(config)
        self.bert = AutoModel.from_config(BertConfig.from_pretrained(self.bert_config_name))
        self.input_size = config.input_size
        self.encoder = TransformerInterEncoder(self.bert.config.hidden_size, max_len=512)

    def _pad_to_input_size(self, values):
        """Right-pad 1-D ``values`` with zeros up to ``self.input_size`` entries.

        Raises:
            ValueError: if ``values`` already holds more than ``input_size``
                entries (the original code failed here with an opaque
                negative-dimension error).
        """
        missing = self.input_size - values.shape[0]
        if missing < 0:
            raise ValueError(
                f"got {values.shape[0]} sentences, but input_size is {self.input_size}"
            )
        # Allocate directly on the target device instead of creating the
        # padding on the CPU and copying it over for every document.
        padding = torch.zeros(missing, device=values.device)
        return torch.cat((values, padding))

    def forward(self, batch):
        """Score every sentence of every document in ``batch``.

        Args:
            batch: mapping with tensors under ``"ids"`` (input token ids),
                ``"segments_ids"`` (token type ids), ``"attn_mask"`` (attention
                mask) and ``"clss_mask"`` (per-token selector for sentence
                representations, presumably boolean). Presumably each is
                shaped (batch, seq_len) -- TODO confirm against the collate fn.

        Returns:
            ``(scores, logits)``: two tensors stacked over the batch. Each
            document's output is right-padded with zeros to ``self.input_size``.

        Raises:
            ValueError: if a document yields more than ``input_size`` sentences.
        """
        device = self.bert.device
        document_ids = batch["ids"].to(device)
        segments_ids = batch["segments_ids"].to(device)
        clss_mask = batch["clss_mask"].to(device)
        attn_mask = batch["attn_mask"].to(device)

        # Take the first output (token states) rather than unpacking a
        # fixed-length tuple, so extra outputs (e.g. hidden states) don't break.
        tokens_out = self.bert(
            input_ids=document_ids,
            token_type_ids=segments_ids,
            attention_mask=attn_mask,
            return_dict=False,
        )[0]

        out = []
        logits_out = []
        # Documents have different sentence counts, so each one is encoded
        # separately and padded to a common length afterwards.
        for token_states, doc_clss_mask in zip(tokens_out, clss_mask):
            clss_out = token_states[doc_clss_mask, :]
            sentences_scores, logits = self.encoder(clss_out)
            out.append(self._pad_to_input_size(sentences_scores))
            # NOTE(review): padded logits are 0 (not -inf); downstream losses or
            # ranking presumably mask padded positions -- confirm.
            logits_out.append(self._pad_to_input_size(logits))

        return torch.stack(out), torch.stack(logits_out)