import torch
from transformers import Pipeline

from .summary import select, splitDocument
from .utilities import padToSize


class ExtSummPipeline(Pipeline):
"""
Extractive summarization pipeline

    Inputs
    ------
    inputs : dict
        'sentences' : list[str]
            Sentences of the document

    strategy : str
        Strategy used to build the summary:
        - 'length': summary capped at a maximum length (strategy_args is the maximum length).
        - 'count': summary with a fixed number of sentences (strategy_args is the number of sentences).
        - 'ratio': summary length proportional to the document length (strategy_args is the ratio, in [0, 1]).
        - 'threshold': summary containing only the sentences whose score exceeds a given value (strategy_args is the minimum score).

    strategy_args : any
        Parameters of the strategy.

    Outputs
    -------
    selected_sents : list[str]
        The selected sentences
    selected_idxs : list[int]
        Indices of the selected sentences in the original input
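
    Example
    -------
    A minimal sketch (assumes `model` and `tokenizer` are an already-loaded,
    compatible pair):

    >>> pipe = ExtSummPipeline(model=model, tokenizer=tokenizer)
    >>> sents, idxs = pipe({"sentences": sentences}, strategy="count", strategy_args=3)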
"""

    def _sanitize_parameters(self, **kwargs):
        # `Pipeline` routes the three returned dicts to `preprocess`,
        # `_forward` and `postprocess` respectively.
        postprocess_kwargs = {}

        if ("strategy" in kwargs) != ("strategy_args" in kwargs):
            raise ValueError("`strategy` and `strategy_args` must both be set")
        if "strategy" in kwargs:
            postprocess_kwargs["strategy"] = kwargs["strategy"]
        if "strategy_args" in kwargs:
            postprocess_kwargs["strategy_args"] = kwargs["strategy_args"]

        return {}, {}, postprocess_kwargs

def preprocess(self, inputs):
sentences = inputs["sentences"]
# Tokenization and chunking
doc_tokens = self.tokenizer.tokenize( f"{self.tokenizer.sep_token}{self.tokenizer.cls_token}".join(sentences) )
doc_tokens = [self.tokenizer.cls_token] + doc_tokens + [self.tokenizer.sep_token]
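        # Resulting layout: [CLS] sent_1 [SEP] [CLS] sent_2 [SEP] ... [CLS] sent_n [SEP];
        # each sentence is prefixed by its own [CLS] token, whose position is where
        # the model emits that sentence's score.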
doc_chunks = splitDocument(doc_tokens, self.tokenizer.cls_token, self.tokenizer.sep_token, self.model.config.input_size)

        # Batch preparation
batch = {
"ids": [],
"clss_mask": [],
"attn_mask": [],
"global_attn_mask": [],
}
for chunk_tokens in doc_chunks:
doc_ids = self.tokenizer.convert_tokens_to_ids(chunk_tokens)
            clss_mask = [token_id == self.tokenizer.cls_token_id for token_id in doc_ids]                     # Positions of the sentences' [CLS] tokens
            attn_mask = [1] * len(doc_ids)
            global_attn_mask = [1 if token_id == self.tokenizer.cls_token_id else 0 for token_id in doc_ids]  # Global attention on the [CLS] tokens
batch["ids"].append( padToSize(doc_ids, self.model.config.input_size, self.tokenizer.pad_token_id) )
batch["clss_mask"].append( padToSize(clss_mask, self.model.config.input_size, False) )
batch["attn_mask"].append( padToSize(attn_mask, self.model.config.input_size, 0) )
batch["global_attn_mask"].append( padToSize(global_attn_mask, self.model.config.input_size, 0) )
batch["ids"] = torch.as_tensor(batch["ids"])
batch["clss_mask"] = torch.as_tensor(batch["clss_mask"])
batch["attn_mask"] = torch.as_tensor(batch["attn_mask"])
batch["global_attn_mask"] = torch.as_tensor(batch["global_attn_mask"])
return { "inputs": batch, "sentences": sentences }
def _forward(self, args):
batch = args["inputs"]
sentences = args["sentences"]
        out_predictions = torch.empty(0, device=self.device)

        self.model.eval()
        with torch.no_grad():
            batch_preds, _ = self.model(batch)
            # Each chunk yields one score per [CLS] position; drop the padded tail
            # and concatenate the chunks' scores in document order.
            for i, clss_mask in enumerate(batch["clss_mask"]):
                out_predictions = torch.cat((out_predictions, batch_preds[i][:clss_mask.sum()]))
return { "predictions": out_predictions, "sentences": sentences }
def postprocess(self, args, strategy: str="count", strategy_args=3):
predictions = args["predictions"]
sentences = args["sentences"]
return select(sentences, predictions, strategy, strategy_args)