import os
import torch

from dotenv import load_dotenv
from datasets import DatasetDict
from dataclasses import dataclass
from transformers import (
    DPRContextEncoder,
    DPRContextEncoderTokenizerFast,
    DPRQuestionEncoder,
    DPRQuestionEncoderTokenizerFast,
    LongformerModel,
    LongformerTokenizer
)
from transformers.modeling_utils import PreTrainedModel
from transformers.tokenization_utils_fast import PreTrainedTokenizerFast

from src.retrievers.base_retriever import RetrieveType, Retriever
from src.utils.log import logger
from src.utils.preprocessing import remove_formulas
from src.utils.timing import timeit


load_dotenv()


@dataclass
class FaissRetrieverOptions:
    ctx_encoder: PreTrainedModel
    ctx_tokenizer: PreTrainedTokenizerFast
    q_encoder: PreTrainedModel
    q_tokenizer: PreTrainedTokenizerFast
    embedding_path: str
    lm: str

    @staticmethod
    def dpr(embedding_path: str):
        return FaissRetrieverOptions(
            ctx_encoder=DPRContextEncoder.from_pretrained(
                "facebook/dpr-ctx_encoder-single-nq-base"
            ),
            ctx_tokenizer=DPRContextEncoderTokenizerFast.from_pretrained(
                "facebook/dpr-ctx_encoder-single-nq-base"
            ),
            q_encoder=DPRQuestionEncoder.from_pretrained(
                "facebook/dpr-question_encoder-single-nq-base"
            ),
            q_tokenizer=DPRQuestionEncoderTokenizerFast.from_pretrained(
                "facebook/dpr-question_encoder-single-nq-base"
            ),
            embedding_path=embedding_path,
            lm="dpr"
        )

    @staticmethod
    def longformer(embedding_path: str):
        encoder = LongformerModel.from_pretrained(
            "valhalla/longformer-base-4096-finetuned-squadv1"
        )
        tokenizer = LongformerTokenizer.from_pretrained(
            "valhalla/longformer-base-4096-finetuned-squadv1"
        )
        return FaissRetrieverOptions(
            ctx_encoder=encoder,
            ctx_tokenizer=tokenizer,
            q_encoder=encoder,
            q_tokenizer=tokenizer,
            embedding_path=embedding_path,
            lm="longformer"
        )
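
# Usage sketch for the factory methods above (illustrative only; these
# embedding paths are placeholders, not files shipped with this module):
#
#   options = FaissRetrieverOptions.dpr("./src/models/dpr.faiss")
#   options = FaissRetrieverOptions.longformer("./src/models/longformer.faiss")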


class FaissRetriever(Retriever):
    """A class used to retrieve relevant documents based on some query.
    based on https://huggingface.co/docs/datasets/faiss_es#faiss.
    """

    def __init__(self, paragraphs: DatasetDict,
                 options: FaissRetrieverOptions) -> None:
        # Inference only: disable gradient tracking globally
        torch.set_grad_enabled(False)

        self.lm = options.lm

        # Context encoding and tokenization
        self.ctx_encoder = options.ctx_encoder
        self.ctx_tokenizer = options.ctx_tokenizer

        # Question encoding and tokenization
        self.q_encoder = options.q_encoder
        self.q_tokenizer = options.q_tokenizer

        self.paragraphs = paragraphs
        self.embedding_path = options.embedding_path

        self.index = self._init_index()

    def _embed_question(self, q):
        match self.lm:
            case "dpr":
                tok = self.q_tokenizer(
                    q, return_tensors="pt", truncation=True, padding=True)
                # Output index 0 is the pooler output; the second index
                # selects the single question in the batch
                return self.q_encoder(**tok)[0][0].numpy()
            case "longformer":
                tok = self.q_tokenizer(q, return_tensors="pt")
                # Use the embedding of the first (<s>) token
                return self.q_encoder(**tok).last_hidden_state[0][0].numpy()
            case _:
                raise ValueError(f"Unsupported language model: {self.lm}")

    def _embed_context(self, row):
        p = row["text"]

        match self.lm:
            case "dpr":
                tok = self.ctx_tokenizer(
                    p, return_tensors="pt", truncation=True, padding=True)
                # Output index 0 is the pooler output; the second index
                # selects the single paragraph in the batch
                enc = self.ctx_encoder(**tok)[0][0].numpy()
                return {"embeddings": enc}
            case "longformer":
                tok = self.ctx_tokenizer(p, return_tensors="pt")
                # Use the embedding of the first (<s>) token
                enc = self.ctx_encoder(**tok).last_hidden_state[0][0].numpy()
                return {"embeddings": enc}
            case _:
                raise ValueError(f"Unsupported language model: {self.lm}")

    def _init_index(
            self,
            force_new_embedding: bool = False):

        ds = self.paragraphs["train"]
        ds = ds.map(remove_formulas)

        if not force_new_embedding and os.path.exists(self.embedding_path):
            # Reuse the previously saved index
            ds.load_faiss_index(
                "embeddings", self.embedding_path)  # type: ignore
            return ds
        else:
            # Add FAISS embeddings
            index = ds.map(self._embed_context)  # type: ignore

            index.add_faiss_index(column="embeddings")

            # Save the index next to the configured embedding path
            os.makedirs(
                os.path.dirname(self.embedding_path) or ".", exist_ok=True)
            index.save_faiss_index("embeddings", self.embedding_path)

            return index

    def retrieve(self, query: str, k: int = 5) -> RetrieveType:
        """Return the scores and contents of the k paragraphs nearest to
        the query."""
        question_embedding = self._embed_question(query)
        scores, results = self.index.get_nearest_examples(
            "embeddings", question_embedding, k=k
        )

        return scores, results
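

if __name__ == "__main__":
    # Minimal usage sketch. Assumptions: the JSON file below stands in for
    # any dataset with a "train" split and a "text" column; it is not
    # shipped with this module, and the embedding path is a placeholder.
    from datasets import load_dataset

    paragraphs = load_dataset(
        "json", data_files="./data/paragraphs.json")  # hypothetical path
    options = FaissRetrieverOptions.dpr("./src/models/paragraphs_dpr.faiss")
    retriever = FaissRetriever(paragraphs, options)

    scores, results = retriever.retrieve("What is a FAISS index?", k=3)
    for score, text in zip(scores, results["text"]):
        logger.info(f"{score:.2f}: {text[:80]}")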