File size: 3,446 Bytes
6f408cf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import torch


def _selectStrategyLength(sentences, predictions, max_length):
    selected_sents = []
    sents_priority = torch.argsort(predictions, descending=True)
    summary_len = 0
    i = 0

    while (summary_len < max_length) and (i < len(sents_priority)):
        if summary_len + len(sentences[sents_priority[i]]) < max_length:
            selected_sents.append(sents_priority[i].item())
            summary_len += len(sentences[sents_priority[i]])
        i += 1

    return sorted(selected_sents)


def _selectStrategyCount(sentences, predictions, num_sents):
    selected_idxs = sorted(torch.topk(predictions, min(len(predictions), num_sents)).indices)
    return [tensor.item() for tensor in selected_idxs]


def _selectStrategyRatio(sentences, predictions, ratio):
    """
        Selects sentences within a budget proportional to the document length.

        The budget is the total length of all the sentences multiplied by
        ``ratio``; the selection itself is delegated to ``_selectStrategyLength``.

        Parameters
        ----------
            sentences : list
                Sentences of the document; ``len()`` of each one is its length.

            predictions : torch.Tensor
                One score per sentence.

            ratio : float
                Fraction of the document length allowed in the summary.

        Returns
        -------
            selected : int[]
                Indexes of the selected sentences, as returned by ``_selectStrategyLength``.
    """
    total_length = sum(len(sentence) for sentence in sentences)
    length_budget = total_length * ratio
    return _selectStrategyLength(sentences, predictions, length_budget)


def _selectStrategyThreshold(sentences, predictions, threshold):
    return [i for i, score in enumerate(predictions) if score >= threshold]


def select(sentences, predictions, strategy, strategy_args):
    """
        Selects sentences of a document according to a selection strategy.

        Parameters
        ----------
            sentences : list
                Sentences of the document.

            predictions : torch.Tensor
                One score per sentence.

            strategy : str
                One of "length", "count", "ratio" or "threshold".

            strategy_args
                Argument forwarded to the chosen strategy: maximum summary length
                ("length"), number of sentences ("count"), fraction of the document
                length ("ratio") or minimum score ("threshold").

        Returns
        -------
            selected_sentences : list
                The selected sentences, in the order of ``selected_idxs``.

            selected_idxs : int[]
                Indexes of the selected sentences, as returned by the strategy.

        Raises
        ------
            NotImplementedError
                If ``strategy`` is not one of the known strategies.
    """
    strategies = (
        ("length", _selectStrategyLength),
        ("count", _selectStrategyCount),
        ("ratio", _selectStrategyRatio),
        ("threshold", _selectStrategyThreshold),
    )

    for name, strategy_fn in strategies:
        if strategy == name:
            indexes = strategy_fn(sentences, predictions, strategy_args)
            break
    else:
        raise NotImplementedError(f"Unknown strategy {strategy}")

    return [sentences[idx] for idx in indexes], indexes


def splitDocument(doc_tokens, bos_token, eos_token, max_size):
    """
        Splits a document into chunks of at most a given size.

        Chunks are cut right after an end-of-sentence token whenever possible.
        If a single sentence does not fit in a chunk, it is truncated to
        ``max_size - 1`` tokens and closed with ``eos_token``; the tokens after
        the cut are dropped up to the next ``bos_token`` (or up to the end of
        the document if there is none).

        Parameters
        ----------
            doc_tokens : str[]
                List of the tokens of the document. It is not modified.

            bos_token : str
                Beginning-of-sentence token.

            eos_token : str
                End-of-sentence token.

            max_size : int
                Maximum size of a chunk. Must be at least 1.

        Returns
        -------
            chunks : str[][]
                Split document. Each chunk is a new list.

        Raises
        ------
            ValueError
                If ``max_size`` is smaller than 1 (the split could never progress).
    """
    if max_size < 1:
        raise ValueError(f"max_size must be at least 1, got {max_size}")

    def _findNextBOSFrom(start_idx):
        # First index >= start_idx holding bos_token, or -1.
        for i in range(start_idx, len(doc_tokens)):
            if doc_tokens[i] == bos_token:
                return i
        return -1

    def _findPreviousEOSFrom(start_idx, stop_idx):
        # Last index in [stop_idx, start_idx] holding eos_token, or -1.
        for i in range(start_idx, stop_idx - 1, -1):
            if doc_tokens[i] == eos_token:
                return i
        return -1

    chunks = []
    start = 0  # Index of the first token not yet assigned to a chunk.

    # Walk the document with an offset instead of re-slicing the remaining
    # tokens at every step, which would be quadratic in the document length.
    while len(doc_tokens) - start > max_size:
        window_end = start + max_size - 1  # Last index that fits in this chunk.
        eos_idx = _findPreviousEOSFrom(window_end, start)

        if eos_idx != -1:
            # Split right after the last complete sentence of the window.
            chunks.append(doc_tokens[start:eos_idx + 1])
            start = eos_idx + 1
        else:
            # The sentence is too long: truncate it and close it with eos.
            chunks.append(doc_tokens[start:window_end] + [eos_token])
            # Resume at the next sentence after the window, or stop if there is none.
            next_bos_idx = _findNextBOSFrom(start + max_size)
            start = next_bos_idx if next_bos_idx != -1 else len(doc_tokens)

    if start < len(doc_tokens):
        chunks.append(doc_tokens[start:])  # Remaining part of the document

    return chunks