File size: 4,603 Bytes
6f408cf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6d735c8
 
6f408cf
 
 
 
 
 
 
6d735c8
 
6f408cf
 
 
 
 
 
 
 
 
 
 
 
6d735c8
 
6f408cf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6d735c8
6f408cf
 
6d735c8
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
from transformers import Pipeline
import torch
from .utilities import padToSize
from .summary import select, splitDocument



class ExtSummPipeline(Pipeline):
    """
        Extractive summarization pipeline.

        The document is given as a list of sentences. Each sentence is scored
        by the model and a subset of the sentences is selected according to
        the requested strategy.

        Inputs
        ------
            inputs : dict
                'sentences' : list[str]
                    Sentences of the document

            strategy : str
                Strategy to summarize the document (default: 'count'):
                - 'length': summary with a maximum length (strategy_args is the maximum length).
                - 'count': summary with the given number of sentences (strategy_args is the number of sentences).
                - 'ratio': summary proportional to the length of the document (strategy_args is the ratio [0, 1]).
                - 'threshold': summary only with sentences with a score higher than a given value (strategy_args is the minimum score).

            strategy_args : any
                Parameters of the strategy (default: 3).
                `strategy` and `strategy_args` must be given together.

            out_scores : bool
                If True, the score for each sentence is returned (default: False).

        Outputs
        -------
            selected_sents : list[str]
                List of the selected sentences

            selected_idxs : list[int]
                List of the indexes of the selected sentences in the original input

            sents_scores : Tensor (optional)
                Scores predicted by the model, only returned if `out_scores` is True
    """

    def _sanitize_parameters(self, **kwargs):
        """
            Routes the keyword arguments to the pipeline steps.

            Only `postprocess` takes parameters ('strategy', 'strategy_args', 'out_scores').

            Raises
            ------
                ValueError
                    If only one of `strategy` and `strategy_args` is given.
        """
        # The two options only make sense as a pair: exactly one of them present is an error
        if ("strategy" in kwargs) != ("strategy_args" in kwargs):
            raise ValueError("`strategy` and `strategy_args` have to be both set")

        postprocess_kwargs = {
            name: kwargs[name]
            for name in ("strategy", "strategy_args", "out_scores")
            if name in kwargs
        }

        return {}, {}, postprocess_kwargs

    def preprocess(self, inputs):
        """
            Tokenizes the sentences, splits the document into chunks and builds the model batch.

            Returns
            -------
                dict
                    'inputs' : dict of tensors ('ids', 'clss_mask', 'attn_mask', 'global_attn_mask'),
                               one row per chunk, each row padded to `model.config.input_size`
                    'sentences' : the input sentences, unchanged
        """
        sentences = inputs["sentences"]
        tokenizer = self.tokenizer
        input_size = self.model.config.input_size

        # Tokenization and chunking
        # The document is laid out as: CLS sent_1 SEP CLS sent_2 SEP ... CLS sent_n SEP
        sents_separator = f"{tokenizer.sep_token}{tokenizer.cls_token}"
        doc_tokens = tokenizer.tokenize(sents_separator.join(sentences))
        doc_tokens = [tokenizer.cls_token] + doc_tokens + [tokenizer.sep_token]
        doc_chunks = splitDocument(doc_tokens, tokenizer.cls_token, tokenizer.sep_token, input_size)

        # Batch preparation
        batch = {
            "ids": [],
            "clss_mask": [],
            "attn_mask": [],
            "global_attn_mask": [],
        }
        cls_token_id = tokenizer.cls_token_id
        for chunk_tokens in doc_chunks:
            doc_ids = tokenizer.convert_tokens_to_ids(chunk_tokens)
            clss_mask = [token == cls_token_id for token in doc_ids]
            attn_mask = [1] * len(doc_ids)
            # Global attention is set exactly on the CLS positions
            global_attn_mask = [1 if is_cls else 0 for is_cls in clss_mask]

            batch["ids"].append(padToSize(doc_ids, input_size, tokenizer.pad_token_id))
            batch["clss_mask"].append(padToSize(clss_mask, input_size, False))
            batch["attn_mask"].append(padToSize(attn_mask, input_size, 0))
            batch["global_attn_mask"].append(padToSize(global_attn_mask, input_size, 0))

        batch = {name: torch.as_tensor(rows) for name, rows in batch.items()}
        return {"inputs": batch, "sentences": sentences}

    def _forward(self, args):
        """
            Runs the model on the batch of chunks and gathers the predictions into a single tensor.

            For each chunk, only the first k predictions are kept, where k is
            the number of CLS positions in that chunk.

            Returns
            -------
                dict
                    'predictions' : Tensor with the concatenated predictions of all the chunks
                    'sentences' : the input sentences, unchanged
        """
        batch = args["inputs"]
        sentences = args["sentences"]

        # The slices are collected and concatenated once at the end instead of
        # growing a tensor at each iteration (which copies it every time).
        # The list is seeded with an empty tensor so that the output is still
        # an empty tensor on `self.device` when there are no chunks.
        chunks_predictions = [torch.as_tensor([]).to(self.device)]

        self.model.eval()
        with torch.no_grad():
            batch_preds, _ = self.model(batch)
            for i, clss_mask in enumerate(batch["clss_mask"]):
                num_sentences = int(clss_mask.sum())
                chunks_predictions.append(batch_preds[i][:num_sentences])
            out_predictions = torch.cat(chunks_predictions)

        return {"predictions": out_predictions, "sentences": sentences}

    def postprocess(self, args, strategy: str = "count", strategy_args=3, out_scores=False):
        """
            Selects the sentences of the summary from the predicted scores.

            The selection is delegated to `select`. If `out_scores` is True,
            the predictions are appended to its output.
        """
        predictions = args["predictions"]
        sentences = args["sentences"]
        out = select(sentences, predictions, strategy, strategy_args)

        if out_scores:
            # NOTE(review): assumes `select` returns a tuple — confirm against its definition
            out += (predictions,)

        return out