|
import torch |
|
from .transformerutils import TransformerInterEncoder |
|
from transformers import PreTrainedModel, AutoModel, LongformerConfig |
|
from .configuration import ExtSummConfig |
|
|
|
|
|
|
|
class LongformerSummarizer(PreTrainedModel):
    """Extractive summarizer: a Longformer encoder followed by an
    inter-sentence transformer that scores each sentence via its [CLS] token.
    """

    config_class = ExtSummConfig

    def __init__(self, config):
        """Build the model.

        Args:
            config: ExtSummConfig; only ``config.input_size`` (max number of
                sentences per document) is read here.
        """
        super().__init__(config)
        # Backbone is built from the architecture config only (random weights);
        # pretrained weights are expected to be loaded through this class's
        # own from_pretrained / state-dict machinery.
        self.longformer = AutoModel.from_config(LongformerConfig.from_pretrained("allenai/longformer-base-4096"))
        # Fixed output length: scores/logits are zero-padded up to this size.
        self.input_size = config.input_size
        self.interSentenceEncoder = TransformerInterEncoder(self.longformer.config.hidden_size, max_len=4096)

    def forward(self, batch):
        """Score the sentences of each document in the batch.

        Args:
            batch: dict of tensors (all moved to the backbone's device here):
                "ids"              — token ids,
                "clss_mask"        — boolean mask marking [CLS] positions
                                     (assumed bool per row — TODO confirm),
                "attn_mask"        — local attention mask,
                "global_attn_mask" — Longformer global attention mask.

        Returns:
            Tuple of two ``(batch, input_size)`` tensors: sentence scores and
            raw logits, each right-padded with zeros up to ``self.input_size``.
        """
        # Hoist the device lookup instead of querying it per tensor.
        device = self.longformer.device
        document_ids = batch["ids"].to(device)
        clss_mask = batch["clss_mask"].to(device)
        attn_mask = batch["attn_mask"].to(device)
        global_attn_mask = batch["global_attn_mask"].to(device)

        tokens_out, _ = self.longformer(
            input_ids=document_ids,
            attention_mask=attn_mask,
            global_attention_mask=global_attn_mask,
            return_dict=False,
        )

        out = []
        logits_out = []

        for doc_tokens, doc_clss_mask in zip(tokens_out, clss_mask):
            # Keep only the hidden states at the [CLS] positions (one per sentence).
            clss_out = doc_tokens[doc_clss_mask, :]
            sentences_scores, logits = self.interSentenceEncoder(clss_out)

            # BUGFIX: create the padding directly with the scores' device AND
            # dtype. The original built a float32 CPU tensor and moved it,
            # which silently type-promoted the concatenated output to float32
            # when the model runs in half precision.
            padding = torch.zeros(
                self.input_size - sentences_scores.shape[0],
                device=sentences_scores.device,
                dtype=sentences_scores.dtype,
            )

            # NOTE(review): padded logit positions are 0 (sigmoid -> 0.5);
            # callers are presumably masking padded sentences — confirm.
            out.append(torch.cat((sentences_scores, padding)))
            logits_out.append(torch.cat((logits, padding)))

        return torch.stack(out), torch.stack(logits_out)