File size: 969 Bytes
a2ac5f7
 
 
f2daaee
a2ac5f7
 
 
 
 
ea7b686
 
 
 
f2daaee
ea7b686
a2ac5f7
 
 
ea7b686
 
 
 
 
a2ac5f7
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
from modules.retriever.faiss_retriever import FaissRetriever
from modules.retriever.chroma_retriever import ChromaRetriever
from modules.retriever.colbert_retriever import ColbertRetriever
from modules.retriever.raptor_retriever import RaptorRetriever


class Retriever:
    """Facade that selects and builds a retriever backend from configuration.

    The backend is chosen by ``config["vectorstore"]["db_option"]``. That value
    must match one of the keys in ``retriever_classes``. The chosen class is
    instantiated with no arguments and stored on ``self.retriever``.
    """

    def __init__(self, config):
        """Store ``config`` and instantiate the configured retriever backend.

        Args:
            config: Nested mapping whose ``["vectorstore"]["db_option"]`` entry
                names the backend to use.

        Raises:
            KeyError: If ``config`` has no ``["vectorstore"]["db_option"]`` entry.
            ValueError: If ``db_option`` does not name a known backend.
        """
        self.config = config
        # Backend name (as it appears in db_option) -> retriever class.
        self.retriever_classes = dict(
            FAISS=FaissRetriever,
            Chroma=ChromaRetriever,
            RAGatouille=ColbertRetriever,
            RAPTOR=RaptorRetriever,
        )
        self._create_retriever()

    def _create_retriever(self):
        """Instantiate the backend named by the config into ``self.retriever``.

        Raises:
            KeyError: If ``self.config`` lacks ``["vectorstore"]["db_option"]``.
            ValueError: If the option has no entry in ``retriever_classes``.
        """
        vectorstore_cfg = self.config["vectorstore"]
        option = vectorstore_cfg["db_option"]
        backend_cls = self.retriever_classes.get(option)
        if backend_cls:
            self.retriever = backend_cls()
            return
        raise ValueError(f"Invalid db_option: {option}")

    def _return_retriever(self, db):
        """Return ``self.retriever.return_retriever(db, self.config)``.

        Args:
            db: Passed through unchanged to the backend's ``return_retriever``.
        """
        backend = self.retriever
        return backend.return_retriever(db, self.config)