File size: 1,631 Bytes
f51bb92
 
 
 
 
8f6647c
 
 
 
 
 
 
f51bb92
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8f6647c
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
from langchain_community.vectorstores import FAISS
from modules.vectorstore.base import VectorStoreBase
import os


class FAISS(FAISS):
    """LangChain ``FAISS`` vector store that also supports ``len()``.

    This subclass deliberately reuses the parent's name so that modules
    importing ``FAISS`` from here get the ``len()``-capable version.
    """

    def __len__(self):
        """Return ``self.index.ntotal``, the vector count of the wrapped index."""
        wrapped_index = self.index
        return wrapped_index.ntotal


class FaissVectorStore(VectorStoreBase):
    """FAISS implementation of ``VectorStoreBase``.

    The index is saved to and loaded from ``<db_path>/db_<db_option>_<model>``.
    The three parts come from the ``db_path``, ``db_option`` and ``model``
    entries of ``config["vectorstore"]``.
    """

    def __init__(self, config):
        """Store *config* and prepare the FAISS factory.

        Args:
            config: Mapping with a ``"vectorstore"`` section containing
                ``db_path``, ``db_option`` and ``model`` keys.
        """
        self.config = config
        # Set by create_database() / load_database(). It stays None until then,
        # so that len() and truth-testing work on a freshly built instance.
        self.vectorstore = None
        self._init_vector_db()

    def _init_vector_db(self):
        """Expose the FAISS class used to build or load the store."""
        # ``from_documents`` and ``load_local`` are classmethods, so the class
        # itself is all that is needed. A placeholder instance (no embeddings,
        # fake index/docstore) is never usable and makes LangChain log a
        # warning about the missing embedding function.
        self.faiss = FAISS

    def _db_path(self):
        """Return the on-disk directory for the current configuration."""
        vs_config = self.config["vectorstore"]
        return os.path.join(
            vs_config["db_path"],
            f"db_{vs_config['db_option']}_{vs_config['model']}",
        )

    def create_database(self, document_chunks, embedding_model):
        """Build a FAISS index from *document_chunks* and save it to disk.

        Args:
            document_chunks: Documents to embed and index.
            embedding_model: Embedding model passed to ``from_documents``.

        Returns:
            The newly created vector store, which is also kept on
            ``self.vectorstore``.
        """
        self.vectorstore = self.faiss.from_documents(
            documents=document_chunks, embedding=embedding_model
        )
        self.vectorstore.save_local(self._db_path())
        return self.vectorstore

    def load_database(self, embedding_model):
        """Load a previously saved FAISS index from disk.

        Args:
            embedding_model: Embedding model used to embed future queries.

        Returns:
            The loaded vector store, which is also kept on ``self.vectorstore``.
        """
        # NOTE(security): load_local unpickles the saved docstore. Only point
        # db_path at directories written by this application (trusted data).
        self.vectorstore = self.faiss.load_local(
            self._db_path(),
            embedding_model,
            allow_dangerous_deserialization=True,
        )
        return self.vectorstore

    def as_retriever(self):
        """Return a retriever over the store.

        Call create_database() or load_database() first.
        """
        return self.vectorstore.as_retriever()

    def __len__(self):
        """Return the number of stored vectors, or 0 if no store is loaded yet."""
        if self.vectorstore is None:
            return 0
        return len(self.vectorstore)