NotXia's picture
Add model
04f3f18 unverified
raw
history blame contribute delete
No virus
1.51 kB
import torch
from .transformerutils import TransformerInterEncoder
from transformers import PreTrainedModel, AutoModel, BertConfig
from .configuration import ExtSummConfig
class BERTSummarizer(PreTrainedModel):
    """Sentence scorer built from a BERT backbone and an inter-sentence encoder.

    The backbone produces per-token outputs; the positions selected by each
    document's ``clss_mask`` are passed to ``TransformerInterEncoder``, and
    the per-document results are zero-padded to ``config.input_size`` so they
    can be stacked into a batch.
    """

    config_class = ExtSummConfig

    def __init__(self, config):
        """Create the backbone and the sentence-level encoder.

        Args:
            config: Model configuration. ``config.input_size`` is the fixed
                length every per-document output row is padded to.
        """
        super().__init__(config)
        # Built from a configuration only: per the Transformers docs,
        # ``from_config`` instantiates the architecture without loading the
        # checkpoint weights.
        self.bert = AutoModel.from_config(
            BertConfig.from_pretrained(
                "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext"
            )
        )
        self.input_size = config.input_size
        self.encoder = TransformerInterEncoder(
            self.bert.config.hidden_size, max_len=512
        )

    def forward(self, batch):
        """Score the sentences of every document in ``batch``.

        Args:
            batch: Mapping with the tensor entries ``"ids"``,
                ``"segments_ids"``, ``"clss_mask"`` and ``"attn_mask"``.
                Each is moved to the backbone's device before use.
                NOTE(review): ``clss_mask`` is presumably a boolean mask of
                shape (batch, seq_len) marking sentence-start tokens --
                confirm against the data loader.

        Returns:
            A tuple ``(scores, logits)``. Each is the stack of the
            per-document encoder outputs, right-padded with zeros up to
            ``self.input_size`` entries per document.

        Raises:
            RuntimeError: If a document selects more positions than
                ``self.input_size``, so its row cannot be padded.
        """
        device = self.bert.device
        document_ids = batch["ids"].to(device)
        segments_ids = batch["segments_ids"].to(device)
        clss_mask = batch["clss_mask"].to(device)
        attn_mask = batch["attn_mask"].to(device)

        # return_dict=False makes the backbone return a plain tuple; only its
        # first element (the per-token outputs) is used here.
        tokens_out, _ = self.bert(
            input_ids=document_ids,
            token_type_ids=segments_ids,
            attention_mask=attn_mask,
            return_dict=False,
        )

        out = []
        logits_out = []
        # Documents are processed one at a time because the number of
        # selected positions can differ between documents.
        for i, doc_tokens in enumerate(tokens_out):
            clss_out = doc_tokens[clss_mask[i], :]
            sentences_scores, logits = self.encoder(clss_out)

            pad_len = self.input_size - sentences_scores.shape[0]
            if pad_len < 0:
                raise RuntimeError(
                    f"Document {i} yields {sentences_scores.shape[0]} sentence "
                    f"scores, which exceeds input_size={self.input_size}"
                )

            # new_zeros inherits dtype and device from the tensor being
            # padded, avoiding a CPU allocation + transfer per document and
            # any dtype mismatch when the model runs in reduced precision.
            out.append(
                torch.cat((sentences_scores, sentences_scores.new_zeros(pad_len)))
            )
            logits_out.append(torch.cat((logits, logits.new_zeros(pad_len))))

        return torch.stack(out), torch.stack(logits_out)