NotXia's picture
Add model
04f3f18 unverified
raw
history blame
4.59 kB
from transformers import Pipeline
import torch
from .utilities import padToSize
from .summary import select, splitDocument
def generateSegmentIds(doc_ids, tokenizer):
    """
    Generates the segments ids for BERT.

    Segment ids alternate between 0 and 1: every token receives the current
    segment id, and the id flips right after each separator token. The
    separator therefore belongs to the segment it closes.

    Parameters
    ----------
    doc_ids : list[int]
        Token ids of a (chunk of a) document.
    tokenizer
        Tokenizer exposing `sep_token_id`.

    Returns
    -------
    segments_ids : list[int]
        One 0/1 segment id for each entry of `doc_ids`.
    """
    # Resolve the separator id once, outside the loop. `tokenizer.vocab` can
    # rebuild the whole vocabulary dict on every access (it does for HF fast
    # tokenizers). Using `sep_token_id` also matches the `sep_token` used when
    # the document is built, instead of hard-coding the "[SEP]" spelling.
    sep_id = tokenizer.sep_token_id
    segments_ids = []
    curr_segment = 0
    for token in doc_ids:
        segments_ids.append(curr_segment)
        if token == sep_id:
            curr_segment = 1 - curr_segment
    return segments_ids
class ExtSummPipeline(Pipeline):
    """
    Extractive summarization pipeline

    Inputs
    ------
    inputs : dict
        'sentences' : list[str]
            Sentences of the document
    strategy : str
        Strategy to summarize the document:
        - 'length': summary with a maximum length (strategy_args is the maximum length).
        - 'count': summary with the given number of sentences (strategy_args is the number of sentences).
        - 'ratio': summary proportional to the length of the document (strategy_args is the ratio [0, 1]).
        - 'threshold': summary only with sentences with a score higher than a given value (strategy_args is the minimum score).
    strategy_args : any
        Parameters of the strategy.

    Outputs
    -------
    (as returned by `select`)
    selected_sents : list[str]
        List of the selected sentences
    selected_idxs : list[int]
        List of the indexes of the selected sentences in the original input
    """

    def _sanitize_parameters(self, **kwargs):
        """
        Route call-time keyword arguments to the pipeline stages.

        Only `strategy` and `strategy_args` are recognized. Both are forwarded
        to `postprocess`; `preprocess` and `_forward` receive no extra kwargs.

        Raises
        ------
        ValueError
            If exactly one of `strategy` / `strategy_args` is given.
        """
        has_strategy = "strategy" in kwargs
        has_strategy_args = "strategy_args" in kwargs
        if has_strategy != has_strategy_args:
            raise ValueError("`strategy` and `strategy_args` have to be both set")

        postprocess_kwargs = {}
        if has_strategy:
            # The check above guarantees both keys are present together.
            postprocess_kwargs["strategy"] = kwargs["strategy"]
            postprocess_kwargs["strategy_args"] = kwargs["strategy_args"]
        return {}, {}, postprocess_kwargs

    def preprocess(self, inputs):
        """
        Tokenize the document and build the batch of model inputs.

        Each sentence is joined with the next by `sep_token + cls_token`, and the
        whole token sequence is wrapped in `cls_token ... sep_token`. The
        sequence is split into chunks by `splitDocument`. Every chunk becomes
        one batch row, padded with `padToSize` to `model.config.input_size`.

        Returns
        -------
        dict
            "inputs": dict of tensors ("ids", "segments_ids", "clss_mask",
            "attn_mask"), one row per chunk;
            "sentences": the input sentences, passed through unchanged.
        """
        sentences = inputs["sentences"]
        tokenizer = self.tokenizer
        input_size = self.model.config.input_size
        cls_id = tokenizer.cls_token_id

        # Tokenization and chunking
        sentence_joiner = f"{tokenizer.sep_token}{tokenizer.cls_token}"
        doc_tokens = tokenizer.tokenize(sentence_joiner.join(sentences))
        doc_tokens = [tokenizer.cls_token] + doc_tokens + [tokenizer.sep_token]
        doc_chunks = splitDocument(doc_tokens, tokenizer.cls_token, tokenizer.sep_token, input_size)

        # Batch preparation
        batch = {
            "ids": [],
            "segments_ids": [],
            "clss_mask": [],
            "attn_mask": [],
        }
        for chunk_tokens in doc_chunks:
            doc_ids = tokenizer.convert_tokens_to_ids(chunk_tokens)
            segment_ids = generateSegmentIds(doc_ids, tokenizer)
            # True at every [CLS] token, i.e. at the start of each sentence.
            clss_mask = [token == cls_id for token in doc_ids]
            attn_mask = [1] * len(doc_ids)

            batch["ids"].append(padToSize(doc_ids, input_size, tokenizer.pad_token_id))
            batch["segments_ids"].append(padToSize(segment_ids, input_size, 0))
            batch["clss_mask"].append(padToSize(clss_mask, input_size, False))
            batch["attn_mask"].append(padToSize(attn_mask, input_size, 0))

        batch = {key: torch.as_tensor(rows) for key, rows in batch.items()}
        return {"inputs": batch, "sentences": sentences}

    def _forward(self, args):
        """
        Run the model on the chunk batch and gather the per-sentence scores.

        For each chunk i, the first n_i rows of the model's predictions for that
        chunk are kept, where n_i is the number of [CLS] positions in the chunk.
        The kept predictions of all chunks are concatenated in chunk order.
        """
        batch = args["inputs"]
        sentences = args["sentences"]
        self.model.eval()
        with torch.no_grad():
            batch_preds, _ = self.model(batch)
            # Collect the per-chunk slices and concatenate once. Growing a
            # tensor with torch.cat inside the loop re-copies all previous
            # predictions on every iteration.
            chunk_preds = [
                batch_preds[i][: int(clss_mask.sum())]
                for i, clss_mask in enumerate(batch["clss_mask"])
            ]
            if chunk_preds:
                out_predictions = torch.cat(chunk_preds)
            else:
                out_predictions = torch.as_tensor([]).to(self.device)
        return {"predictions": out_predictions, "sentences": sentences}

    def postprocess(self, args, strategy: str = "count", strategy_args=3):
        """
        Select the summary sentences from the model scores.

        Parameters
        ----------
        args : dict
            Output of `_forward` ("predictions" and "sentences").
        strategy : str
            Selection strategy forwarded to `select` (default "count").
        strategy_args : any
            Parameter of the strategy (default 3).

        Returns
        -------
        Whatever `select` returns for these sentences and scores.
        """
        predictions = args["predictions"]
        sentences = args["sentences"]
        return select(sentences, predictions, strategy, strategy_args)