File size: 1,583 Bytes
6f408cf |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 |
import torch
from .transformerutils import TransformerInterEncoder
from transformers import PreTrainedModel, AutoModel, LongformerConfig
from .configuration import ExtSummConfig
class LongformerSummarizer(PreTrainedModel):
    """Extractive summarizer: Longformer token encoder + inter-sentence scorer.

    The token embeddings selected by ``clss_mask`` (presumably one position per
    sentence, e.g. the sentence-start markers -- confirm against the data
    pipeline) are passed to ``TransformerInterEncoder``. Its two per-document
    outputs are zero-padded to ``config.input_size`` so that every document in
    the batch yields a row of the same length and the rows can be stacked.
    """

    config_class = ExtSummConfig

    def __init__(self, config):
        """Build the Longformer backbone and the inter-sentence encoder.

        Args:
            config: an ``ExtSummConfig``; only ``config.input_size`` is read
                here (the fixed number of sentence slots per output row).
        """
        super().__init__(config)
        # ``AutoModel.from_config`` builds the architecture only; it does not
        # load pretrained weights. NOTE(review): weights are presumably loaded
        # later through ``from_pretrained`` on this class -- confirm. Fetching
        # the Longformer config itself may need network access or a local cache.
        self.longformer = AutoModel.from_config(
            LongformerConfig.from_pretrained("allenai/longformer-base-4096")
        )
        # Number of sentence slots in every output row; shorter documents are
        # zero-padded up to this length in ``forward``.
        self.input_size = config.input_size
        self.interSentenceEncoder = TransformerInterEncoder(
            self.longformer.config.hidden_size, max_len=4096
        )

    @staticmethod
    def _pad_to_length(values, length):
        """Right-pad 1-D ``values`` with zeros (same dtype and device) to ``length``.

        Raises:
            ValueError: if ``values`` already has more than ``length`` entries.
        """
        missing = length - values.shape[0]
        if missing < 0:
            # Without this check torch fails with an opaque
            # "negative dimension" error while building the padding.
            raise ValueError(
                f"Document has {values.shape[0]} sentences but the model "
                f"supports at most input_size={length}."
            )
        # ``new_zeros`` allocates directly on values' device with values'
        # dtype: no CPU allocation + copy, no dtype promotion in ``cat``.
        return torch.cat((values, values.new_zeros(missing)))

    def forward(self, batch):
        """Score the selected sentence positions of every document in ``batch``.

        Args:
            batch: mapping with tensor entries ``"ids"``, ``"attn_mask"``,
                ``"global_attn_mask"`` (passed to the Longformer) and
                ``"clss_mask"`` (used to index each document's token
                embeddings); every entry is moved to the Longformer's device.

        Returns:
            ``(scores, logits)``: the two outputs of ``interSentenceEncoder``
            for each document, each right-padded with zeros to
            ``self.input_size`` and stacked along a new leading batch
            dimension. Assumes the encoder returns 1-D tensors -- TODO confirm.

        Raises:
            ValueError: if a document has more selected sentences than
                ``self.input_size``.
        """
        device = self.longformer.device
        outputs = self.longformer(
            input_ids=batch["ids"].to(device),
            attention_mask=batch["attn_mask"].to(device),
            global_attention_mask=batch["global_attn_mask"].to(device),
            return_dict=False,
        )
        # Index rather than 2-tuple unpacking: the returned tuple grows when
        # the config enables hidden-state or attention outputs.
        tokens_out = outputs[0]
        clss_mask = batch["clss_mask"].to(device)

        scores_rows = []
        logits_rows = []
        # Per-document loop: each document can select a different number of
        # sentence positions, so the encoder runs on each one separately.
        for doc_tokens, doc_mask in zip(tokens_out, clss_mask):
            clss_out = doc_tokens[doc_mask, :]
            sentences_scores, logits = self.interSentenceEncoder(clss_out)
            scores_rows.append(self._pad_to_length(sentences_scores, self.input_size))
            logits_rows.append(self._pad_to_length(logits, self.input_size))
        return torch.stack(scores_rows), torch.stack(logits_rows)