import torch
from transformers import Pipeline

from .utilities import padToSize
from .summary import select, splitDocument


class ExtSummPipeline(Pipeline):
    """
    Extractive summarization pipeline

    Inputs
    ------
        inputs : dict
            'sentences' : list[str]
                Sentences of the document

        strategy : str
            Strategy to summarize the document:
            - 'length': summary with a maximum length (strategy_args is the maximum length).
            - 'count': summary with the given number of sentences (strategy_args is the number of sentences).
            - 'ratio': summary proportional to the length of the document (strategy_args is the ratio [0, 1]).
            - 'threshold': summary only with sentences with a score higher than a given value (strategy_args is the minimum score).
            When neither `strategy` nor `strategy_args` is given, 'count' is used.

        strategy_args : any
            Parameters of the strategy. Must be given together with `strategy`;
            when neither is given, 3 is used.

    Outputs
    -------
        selected_sents : list[str]
            List of the selected sentences

        selected_idxs : list[int]
            List of the indexes of the selected sentences in the original input
    """

    def _sanitize_parameters(self, **kwargs):
        """Route call-time keyword arguments to the pipeline stages.

        Only ``strategy`` and ``strategy_args`` are recognised. Both are
        forwarded to ``postprocess``; ``preprocess`` and ``_forward`` receive
        no extra arguments.

        Returns
        -------
        tuple[dict, dict, dict]
            Keyword arguments for preprocess, _forward and postprocess.

        Raises
        ------
        ValueError
            If only one of ``strategy`` / ``strategy_args`` is given.
        """
        has_strategy = "strategy" in kwargs
        has_strategy_args = "strategy_args" in kwargs
        # The meaning of strategy_args depends on strategy, so they must come together.
        if has_strategy != has_strategy_args:
            raise ValueError("`strategy` and `strategy_args` have to be both set")

        postprocess_kwargs = {}
        if has_strategy:
            postprocess_kwargs["strategy"] = kwargs["strategy"]
            postprocess_kwargs["strategy_args"] = kwargs["strategy_args"]
        return {}, {}, postprocess_kwargs

    def preprocess(self, inputs):
        """Tokenize the document, split it into model-sized chunks and build padded tensors.

        Parameters
        ----------
        inputs : dict
            Must contain ``'sentences'`` (list[str]).

        Returns
        -------
        dict
            ``'inputs'``: dict with the tensors ``ids``, ``clss_mask``,
            ``attn_mask`` and ``global_attn_mask``. Each tensor has one row per
            chunk, padded to ``self.model.config.input_size``.
            ``'sentences'``: the input sentence list, passed through unchanged.
        """
        sentences = inputs["sentences"]
        tokenizer = self.tokenizer
        cls_token, sep_token = tokenizer.cls_token, tokenizer.sep_token
        cls_id = tokenizer.cls_token_id
        input_size = self.model.config.input_size

        # Lay the document out as [CLS] s1 [SEP][CLS] s2 [SEP] ... [SEP]:
        # join with "[SEP][CLS]", then add a leading [CLS] and a trailing [SEP].
        doc_tokens = tokenizer.tokenize(f"{sep_token}{cls_token}".join(sentences))
        doc_tokens = [cls_token] + doc_tokens + [sep_token]
        doc_chunks = splitDocument(doc_tokens, cls_token, sep_token, input_size)

        batch = {"ids": [], "clss_mask": [], "attn_mask": [], "global_attn_mask": []}
        for chunk_tokens in doc_chunks:
            doc_ids = tokenizer.convert_tokens_to_ids(chunk_tokens)
            # Build every unpadded row before padding anything, in case padToSize
            # pads its argument in place.
            is_cls = [token == cls_id for token in doc_ids]
            global_attn = [int(flag) for flag in is_cls]  # [CLS] tokens get global attention
            attn = [1] * len(doc_ids)  # real tokens are attended; padding (0) is not

            batch["ids"].append(padToSize(doc_ids, input_size, tokenizer.pad_token_id))
            batch["clss_mask"].append(padToSize(is_cls, input_size, False))
            batch["attn_mask"].append(padToSize(attn, input_size, 0))
            batch["global_attn_mask"].append(padToSize(global_attn, input_size, 0))

        return {
            "inputs": {key: torch.as_tensor(rows) for key, rows in batch.items()},
            "sentences": sentences,
        }

    def _forward(self, args):
        """Run the model on all chunks and flatten the per-sentence predictions.

        Returns
        -------
        dict
            ``'predictions'``: each chunk's predictions, truncated to the number
            of [CLS] tokens in that chunk, concatenated in chunk order.
            ``'sentences'``: passed through unchanged.
        """
        batch = args["inputs"]

        self.model.eval()
        with torch.no_grad():
            batch_preds, _ = self.model(batch)

        # Keep one prediction per [CLS] token of each chunk. This assumes the model
        # emits CLS-aligned predictions first and padding after.
        # NOTE(review): confirm against the model's output layout.
        chunk_preds = [
            batch_preds[i][: int(clss_mask.sum())]
            for i, clss_mask in enumerate(batch["clss_mask"])
        ]
        # A single cat replaces the previous cat-per-chunk loop, which was quadratic.
        # The empty float seed preserves the old result dtype (promotion against the
        # default float dtype) and yields an empty tensor when there are no chunks.
        out_predictions = torch.cat([torch.as_tensor([]).to(self.device), *chunk_preds])

        return {
            "predictions": out_predictions,
            "sentences": args["sentences"],
        }

    def postprocess(self, args, strategy: str = "count", strategy_args=3):
        """Choose the summary sentences from the model predictions.

        Delegates to ``select(sentences, predictions, strategy, strategy_args)``.
        See the class docstring for the available strategies.
        """
        return select(args["sentences"], args["predictions"], strategy, strategy_args)