|
import torch
from transformers import Pipeline

from .utilities import padToSize
from .summary import select, splitDocument
""" |
|
Generates the segments ids for BERT |
|
""" |
|
def generateSegmentIds(doc_ids, tokenizer): |
|
|
|
segments_ids = [0] * len(doc_ids) |
|
curr_segment = 0 |
|
|
|
for i, token in enumerate(doc_ids): |
|
segments_ids[i] = curr_segment |
|
if token == tokenizer.vocab["[SEP]"]: |
|
curr_segment = 1 - curr_segment |
|
|
|
return segments_ids |
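
# Worked example (illustrative, tokens shown in place of their ids): for the
# sequence [CLS] w1 [SEP] [CLS] w2 w3 [SEP], the segment id flips after each
# [SEP], so generateSegmentIds returns [0, 0, 0, 1, 1, 1, 1].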


class ExtSummPipeline(Pipeline):
    """
    Extractive summarization pipeline.

    Inputs
    ------
    inputs : dict
        'sentences' : list[str]
            Sentences of the document.

    strategy : str
        Strategy used to build the summary:
        - 'length': summary with a maximum length (strategy_args is the maximum length).
        - 'count': summary with a given number of sentences (strategy_args is the number of sentences).
        - 'ratio': summary whose length is proportional to the length of the document (strategy_args is the ratio, in [0, 1]).
        - 'threshold': summary made only of the sentences whose score is higher than a given value (strategy_args is the minimum score).

    strategy_args : any
        Parameters of the strategy.

    Outputs
    -------
    selected_sents : list[str]
        The selected sentences.

    selected_idxs : list[int]
        Indexes of the selected sentences in the original input.
    """

    def _sanitize_parameters(self, **kwargs):
        postprocess_kwargs = {}

        # The selection strategy and its parameter are consumed by
        # postprocess() and must be provided together.
        if ("strategy" in kwargs) != ("strategy_args" in kwargs):
            raise ValueError("`strategy` and `strategy_args` have to be both set")
        if "strategy" in kwargs:
            postprocess_kwargs["strategy"] = kwargs["strategy"]
        if "strategy_args" in kwargs:
            postprocess_kwargs["strategy_args"] = kwargs["strategy_args"]

        return {}, {}, postprocess_kwargs

    def preprocess(self, inputs):
        sentences = inputs["sentences"]

        # Tokenize in the BERTSUM input format: every sentence is enclosed
        # between a [CLS] and a [SEP] token, so that each [CLS] position can
        # later hold the representation of one sentence.
        doc_tokens = self.tokenizer.tokenize(
            f"{self.tokenizer.sep_token}{self.tokenizer.cls_token}".join(sentences)
        )
        doc_tokens = [self.tokenizer.cls_token] + doc_tokens + [self.tokenizer.sep_token]
        # Documents longer than the model input size are split into chunks.
        doc_chunks = splitDocument(
            doc_tokens, self.tokenizer.cls_token, self.tokenizer.sep_token, self.model.config.input_size
        )

        # Build one padded example per chunk.
        batch = {
            "ids": [],
            "segments_ids": [],
            "clss_mask": [],
            "attn_mask": [],
        }
        for chunk_tokens in doc_chunks:
            doc_ids = self.tokenizer.convert_tokens_to_ids(chunk_tokens)
            segment_ids = generateSegmentIds(doc_ids, self.tokenizer)
            clss_mask = [token == self.tokenizer.cls_token_id for token in doc_ids]
            attn_mask = [1] * len(doc_ids)

            batch["ids"].append(padToSize(doc_ids, self.model.config.input_size, self.tokenizer.pad_token_id))
            batch["segments_ids"].append(padToSize(segment_ids, self.model.config.input_size, 0))
            batch["clss_mask"].append(padToSize(clss_mask, self.model.config.input_size, False))
            batch["attn_mask"].append(padToSize(attn_mask, self.model.config.input_size, 0))

        batch = {key: torch.as_tensor(values) for key, values in batch.items()}
        return {"inputs": batch, "sentences": sentences}
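
    # Shape sketch (illustrative, assuming input_size == 512 and a document
    # that splitDocument breaks into two chunks): each entry of `batch` comes
    # out as a (2, 512) tensor, one padded row per chunk.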

    def _forward(self, args):
        batch = args["inputs"]
        sentences = args["sentences"]
        out_predictions = torch.as_tensor([]).to(self.device)

        self.model.eval()
        with torch.no_grad():
            batch_preds, _ = self.model(batch)
            # Keep only the scores of real sentences: clss_mask is True at the
            # [CLS] position of each actual sentence and False on the padding.
            for i, clss_mask in enumerate(batch["clss_mask"]):
                out_predictions = torch.cat((out_predictions, batch_preds[i][:torch.sum(clss_mask)]))

        return {"predictions": out_predictions, "sentences": sentences}

    def postprocess(self, args, strategy: str = "count", strategy_args=3):
        # By default, pick the 3 highest-scoring sentences.
        predictions = args["predictions"]
        sentences = args["sentences"]
        return select(sentences, predictions, strategy, strategy_args)
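

# Minimal usage sketch (hypothetical: the model class and the checkpoint paths
# are placeholders, not part of this module):
#
#   from transformers import AutoTokenizer
#   from .model import ExtSummModel  # hypothetical model class
#
#   model = ExtSummModel.from_pretrained("path/to/checkpoint")
#   tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
#   pipe = ExtSummPipeline(model=model, tokenizer=tokenizer)
#
#   # Two highest-scoring sentences:
#   selected_sents, selected_idxs = pipe(
#       {"sentences": ["First sentence.", "Second sentence.", "Third one."]},
#       strategy="count",
#       strategy_args=2,
#   )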