import torch
from .transformerutils import TransformerInterEncoder
from transformers import PreTrainedModel, AutoModel, BertConfig
from .configuration import ExtSummConfig

# Hub checkpoint whose *configuration* defines the BERT backbone architecture.
# Only the config is fetched here; see the note in BERTSummarizer.__init__.
_BERT_CONFIG_CHECKPOINT = "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext"


class BERTSummarizer(PreTrainedModel):
    """Extractive summarizer: a BERT token encoder followed by an
    inter-sentence transformer that scores each sentence.

    For every document in a batch, the BERT hidden states at the positions
    selected by ``clss_mask`` (presumably one marker token per sentence --
    TODO confirm against the data pipeline) are fed to
    ``TransformerInterEncoder``. Its per-sentence outputs are zero-padded
    to ``config.input_size`` so the batch can be stacked.
    """

    config_class = ExtSummConfig

    def __init__(self, config):
        """Build the backbone and sentence encoder from ``config``.

        Args:
            config: An ``ExtSummConfig``; ``config.input_size`` is the fixed
                number of sentence slots each output row is padded to.
        """
        super().__init__(config)
        # ``AutoModel.from_config`` builds the architecture without loading
        # pretrained weights; weights are expected to come from loading the
        # whole summarizer (e.g. via ``from_pretrained``).
        self.bert = AutoModel.from_config(
            BertConfig.from_pretrained(_BERT_CONFIG_CHECKPOINT)
        )
        self.input_size = config.input_size
        self.encoder = TransformerInterEncoder(
            self.bert.config.hidden_size, max_len=512
        )

    def forward(self, batch):
        """Score the sentences of every document in ``batch``.

        Args:
            batch: Mapping with tensor entries ``"ids"`` (token ids),
                ``"segments_ids"`` (token type ids), ``"clss_mask"``
                (per-document index/mask selecting sentence-marker positions;
                presumably boolean -- TODO confirm) and ``"attn_mask"``
                (attention mask). All are moved to the backbone's device.

        Returns:
            Tuple ``(scores, logits)``. Each is ``torch.stack`` of one
            per-document tensor zero-padded along dim 0 to
            ``self.input_size``. This assumes the encoder returns 1-D
            per-sentence tensors -- TODO confirm.

        Raises:
            RuntimeError: Raised by torch when a document yields more than
                ``self.input_size`` sentences (negative padding size).
        """
        device = self.bert.device
        document_ids = batch["ids"].to(device)
        segments_ids = batch["segments_ids"].to(device)
        clss_mask = batch["clss_mask"].to(device)
        attn_mask = batch["attn_mask"].to(device)

        tokens_out, _ = self.bert(
            input_ids=document_ids,
            token_type_ids=segments_ids,
            attention_mask=attn_mask,
            return_dict=False,
        )

        out = []
        logits_out = []
        # Process documents one at a time, since each may contain a
        # different number of sentences.
        for doc_tokens, doc_clss_mask in zip(tokens_out, clss_mask):
            clss_out = doc_tokens[doc_clss_mask, :]
            sentences_scores, logits = self.encoder(clss_out)
            pad_len = self.input_size - sentences_scores.shape[0]
            # new_zeros matches each tensor's dtype and device, which avoids
            # a CPU allocation plus host-to-device copy per document and
            # avoids dtype mismatches under mixed precision.
            out.append(
                torch.cat((sentences_scores, sentences_scores.new_zeros(pad_len)))
            )
            logits_out.append(torch.cat((logits, logits.new_zeros(pad_len))))

        return torch.stack(out), torch.stack(logits_out)