|
import torch |
|
|
|
|
|
def _selectStrategyLength(sentences, predictions, max_length): |
|
selected_sents = [] |
|
sents_priority = torch.argsort(predictions, descending=True) |
|
summary_len = 0 |
|
i = 0 |
|
|
|
while (summary_len < max_length) and (i < len(sents_priority)): |
|
if summary_len + len(sentences[sents_priority[i]]) < max_length: |
|
selected_sents.append(sents_priority[i].item()) |
|
summary_len += len(sentences[sents_priority[i]]) |
|
i += 1 |
|
|
|
return sorted(selected_sents) |
|
|
|
|
|
def _selectStrategyCount(sentences, predictions, num_sents): |
|
selected_idxs = sorted(torch.topk(predictions, min(len(predictions), num_sents)).indices) |
|
return [tensor.item() for tensor in selected_idxs] |
|
|
|
|
|
def _selectStrategyRatio(sentences, predictions, ratio):
    """Select sentences up to a fraction ``ratio`` of the total document length.

    The budget is ``ratio`` times the summed sentence lengths; selection is
    then delegated to the length strategy. Returns indices in document order.
    """
    budget = ratio * sum(map(len, sentences))
    return _selectStrategyLength(sentences, predictions, budget)
|
|
|
|
|
def _selectStrategyThreshold(sentences, predictions, threshold): |
|
return [i for i, score in enumerate(predictions) if score >= threshold] |
|
|
|
|
|
def select(sentences, predictions, strategy, strategy_args):
    """Select sentences from a document using the given selection strategy.

    Strategies: "length" (budget in characters), "count" (number of
    sentences), "ratio" (fraction of document length), "threshold"
    (minimum score).

    Returns a tuple ``(selected_sentences, selected_indices)`` with indices
    in document order.

    Raises NotImplementedError for an unknown strategy name.
    """
    dispatch = {
        "length": _selectStrategyLength,
        "count": _selectStrategyCount,
        "ratio": _selectStrategyRatio,
        "threshold": _selectStrategyThreshold,
    }

    if strategy not in dispatch:
        raise NotImplementedError(f"Unknown strategy {strategy}")

    selected_sents = dispatch[strategy](sentences, predictions, strategy_args)
    return [sentences[i] for i in selected_sents], selected_sents
|
|
|
|
|
|
|
""" |
|
Splits a document in chunks of maximum a given size. |
|
|
|
Parameters |
|
---------- |
|
doc_tokens : str[] |
|
List of the tokens of the document. |
|
|
|
bos_token : str |
|
Begin of sentence token. |
|
|
|
eos_token : str |
|
End of sentence token. |
|
|
|
max_size : int |
|
Maximum size of a chunk. |
|
Returns |
|
------- |
|
chunks : str[][] |
|
Splitted document. |
|
""" |
|
def splitDocument(doc_tokens, bos_token, eos_token, max_size): |
|
def _findNextBOSFrom(start_idx): |
|
for i in range(start_idx, len(doc_tokens)): |
|
if doc_tokens[i] == bos_token: |
|
return i |
|
return -1 |
|
|
|
def _findPreviousEOSFrom(start_idx): |
|
for i in range(start_idx, -1, -1): |
|
if doc_tokens[i] == eos_token: |
|
return i |
|
return -1 |
|
|
|
chunks = [] |
|
|
|
while len(doc_tokens) > max_size: |
|
|
|
eos_idx = _findPreviousEOSFrom(max_size - 1) |
|
|
|
if eos_idx == -1: |
|
|
|
|
|
next_bos_idx = _findNextBOSFrom(max_size) |
|
if next_bos_idx != -1: |
|
doc_tokens = doc_tokens[:max_size-1] + [eos_token] + doc_tokens[next_bos_idx:] |
|
else: |
|
doc_tokens = doc_tokens[:max_size-1] + [eos_token] |
|
eos_idx = max_size - 1 |
|
|
|
chunks.append(doc_tokens[:eos_idx+1]) |
|
doc_tokens = doc_tokens[eos_idx+1:] |
|
|
|
if len(doc_tokens) > 0: chunks.append(doc_tokens) |
|
|
|
return chunks |