File size: 3,611 Bytes
7c3c80c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import numpy as np
from typing import List, Optional, Union
from tqdm import tqdm
from collections import defaultdict



class BM25Retriever:
    """In-memory BM25 retriever.

    Documents and queries are either strings (tokenised by splitting on a
    single space) or pre-tokenised lists of integer token ids.

    Attributes populated by :meth:`index`:
        dfs: token -> number of documents containing the token.
        tfs: per-document dict of token -> term frequency.
        doc_length: float32 array with the token count of every document
            (stop tokens included).
        avg_doc_length: mean of ``doc_length`` (1.0 for an empty index).
        inverted_lists: token -> int64 array of document offsets.
        N: number of indexed documents.
    """

    def __init__(self, k1: float = 0.9, b: float = 0.4) -> None:
        """Store the default BM25 hyper-parameters.

        Args:
            k1: term-frequency saturation parameter.
            b: document-length normalisation strength.
        """
        self.k1 = k1
        self.b = b

        # Start from an explicit empty index so that ``search`` called before
        # ``index`` yields empty results instead of an AttributeError.
        self.dfs = {}
        self.tfs = []
        self.doc_length = np.zeros(0, dtype=np.float32)
        self.avg_doc_length = 1.0
        self.inverted_lists = {}
        self.N = 0

    def index(self, corpus: List[Union[str, List[int]]], verbose: bool = False, stop_tokens: Optional[set] = None):
        """Build the in-memory BM25 index, replacing any previous one.

        Args:
            corpus: documents to index; each is a string or a list of token ids.
            verbose: show a tqdm progress bar while indexing.
            stop_tokens: tokens to leave out of the index. They still count
                towards the document length.
        """
        if stop_tokens is None:
            stop_tokens = set()

        num_docs = len(corpus)
        dfs = defaultdict(int)
        tfs = []
        inverted_lists = defaultdict(list)
        doc_lengths = np.zeros(num_docs, dtype=np.float32)

        iterator = tqdm(corpus, desc="Indexing") if verbose else corpus

        for i, doc in enumerate(iterator):
            if isinstance(doc, str):
                doc = doc.split(" ")

            tf = defaultdict(int)
            for token in doc:
                if token not in stop_tokens:
                    tf[token] += 1
            tfs.append(dict(tf))

            # The keys of ``tf`` are exactly the distinct indexed tokens of
            # this document, so they drive both the document frequency and
            # the postings.
            for token in tf:
                dfs[token] += 1
                # store the doc offset in the inverted list of the token
                inverted_lists[token].append(i)

            doc_lengths[i] = len(doc)

        total_length = float(doc_lengths.sum(dtype=np.float64))

        self.dfs = dict(dfs)
        self.tfs = tfs
        self.doc_length = doc_lengths
        # Fall back to 1.0 so the length normalisation never divides by zero
        # (empty corpus, or a corpus made only of empty documents).
        self.avg_doc_length = total_length / num_docs if total_length > 0 else 1.0
        self.inverted_lists = {k: np.array(v, dtype=np.int64) for k, v in inverted_lists.items()}
        self.N = num_docs

    def search(self, queries: Union[str, List[int], List[str], List[List[int]]], hits: int = 100, k1: Optional[float] = None, b: Optional[float] = None, verbose: bool = False):
        """Search over the BM25 index.

        Args:
            queries: a single query (string or list of token ids) or a list
                of such queries. An empty list is treated as zero queries.
            hits: maximum number of results per query; clamped to the number
                of indexed documents.
            k1: overrides the instance-level ``k1`` when given.
            b: overrides the instance-level ``b`` when given.
            verbose: show a tqdm progress bar while searching.

        Returns:
            ``(all_scores, all_indices)``: a float32 and an int64 array, both
            of shape ``(num_queries, hits)``, sorted by descending score.
            Slots with no matching document hold score ``-inf`` and index ``-1``.
        """
        if k1 is None:
            k1 = self.k1
        if b is None:
            b = self.b

        if isinstance(queries, str):
            queries = [queries]
        elif isinstance(queries, list) and queries and isinstance(queries[0], (int, np.integer)):
            # a single pre-tokenised query
            queries = [queries]

        hits = max(0, min(self.N, hits))

        all_scores = np.full((len(queries), hits), -np.inf, dtype=np.float32)
        all_indices = np.full((len(queries), hits), -1, dtype=np.int64)

        if hits == 0:
            # Nothing can be retrieved (empty index or hits <= 0); also keeps
            # argpartition from being called on an empty array.
            return all_scores, all_indices

        iterator = tqdm(queries, desc="Searching") if verbose else queries

        for i, query in enumerate(iterator):
            if isinstance(query, str):
                query = query.split(" ")
                # TODO: stem

            # Fresh accumulator for every query: sharing one buffer across
            # queries would leak scores from earlier queries into later ones.
            doc_scores = np.zeros(self.N, dtype=np.float32)

            for token in query:
                candidates = self.inverted_lists.get(token)
                if candidates is None:
                    continue

                tfs = np.array([self.tfs[candidate][token] for candidate in candidates], dtype=np.float32)
                df = self.dfs[token]
                idf = np.log((self.N - df + 0.5) / (df + 0.5) + 1)

                # BM25 normalises by the length relative to the corpus average.
                length_norm = 1 - b + b * self.doc_length[candidates] / self.avg_doc_length
                doc_scores[candidates] += idf * (k1 + 1) * tfs / (tfs + k1 * length_norm)

            # Partial selection of the top `hits` documents, then an exact
            # sort of only that slice.
            indice = np.argpartition(-doc_scores, hits - 1)[:hits]
            score = doc_scores[indice]

            sorted_idx = np.argsort(score)[::-1]
            indice = indice[sorted_idx]
            score = score[sorted_idx]

            # A zero score means the document matched no query token.
            invalid_pos = score == 0
            indice[invalid_pos] = -1
            score[invalid_pos] = -float('inf')

            all_scores[i] = score
            all_indices[i] = indice
        return all_scores, all_indices