File size: 3,460 Bytes
6158da4
 
 
 
 
 
 
db6b619
6158da4
 
 
a052bdc
6158da4
57b7b8d
6158da4
 
 
 
 
 
b83cc65
6158da4
 
 
 
 
 
 
a052bdc
6158da4
 
 
 
 
 
f0018f2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6158da4
db6b619
57b7b8d
 
 
 
6158da4
 
 
 
57b7b8d
6158da4
 
 
 
 
 
 
 
57b7b8d
6158da4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
from langchain import PromptTemplate
from langchain.embeddings import HuggingFaceEmbeddings
from langchain_community.chat_models import ChatOpenAI
from langchain_community.embeddings import OpenAIEmbeddings
from langchain.vectorstores import FAISS
from langchain.chains import RetrievalQA, ConversationalRetrievalChain
from langchain.llms import CTransformers
from langchain.memory import ConversationBufferWindowMemory
from langchain.chains.conversational_retrieval.prompts import QA_PROMPT
import os
from modules.constants import *
from modules.helpers import get_prompt
from modules.chat_model_loader import ChatModelLoader
from modules.vector_db import VectorDB, VectorDBScore


class LLMTutor:
    """Assemble a retrieval question-answering chain from a config mapping.

    The class owns a ``VectorDB`` wrapper and, on demand, combines the
    loaded database, the configured prompt and the chat model into a
    LangChain QA chain.
    """

    def __init__(self, config, logger=None):
        """Store ``config`` and prepare the vector database wrapper.

        Args:
            config: Mapping read with ``config["embedding_options"][...]``
                and ``config["llm_params"][...]``.
            logger: Optional logger forwarded to ``VectorDB``.
        """
        self.config = config
        self.vector_db = VectorDB(config, logger=logger)
        # When requested by the config, build the database and save it
        # right away so that a later ``load_database()`` can pick it up.
        if self.config["embedding_options"]["embedd_files"]:
            self.vector_db.create_database()
            self.vector_db.save_database()

    def set_custom_prompt(self):
        """Return the prompt produced by ``get_prompt`` for this config."""
        return get_prompt(self.config)

    def _build_retriever(self, db):
        """Create the retriever matching ``embedding_options.db_option``.

        Raises:
            ValueError: If ``db_option`` is not one of the supported values.
        """
        embedding_options = self.config["embedding_options"]
        db_option = embedding_options["db_option"]

        if db_option in ("FAISS", "Chroma"):
            return VectorDBScore(
                vectorstore=db,
                search_type="similarity_score_threshold",
                search_kwargs={
                    "score_threshold": embedding_options["score_threshold"],
                    "k": embedding_options["search_top_k"],
                },
            )
        if db_option == "RAGatouille":
            return db.as_langchain_retriever(k=embedding_options["search_top_k"])

        # Fail with a clear message instead of an UnboundLocalError later on.
        raise ValueError(
            f"Unsupported db_option {db_option!r}; "
            "expected 'FAISS', 'Chroma' or 'RAGatouille'"
        )

    # Retrieval QA Chain
    def retrieval_qa_chain(self, llm, prompt, db):
        """Build the QA chain for ``llm`` over ``db`` using ``prompt``.

        Args:
            llm: Language model handed to the LangChain chain constructor.
            prompt: Prompt used by the "stuff" documents chain.
            db: Loaded vector database the retriever is created from.

        Returns:
            A ``ConversationalRetrievalChain`` with windowed memory when
            ``llm_params.use_history`` is truthy, otherwise a ``RetrievalQA``
            chain. Both are configured to return their source documents.

        Raises:
            ValueError: If ``embedding_options.db_option`` is not supported.
        """
        retriever = self._build_retriever(db)

        if self.config["llm_params"]["use_history"]:
            memory = ConversationBufferWindowMemory(
                k=self.config["llm_params"]["memory_window"],
                memory_key="chat_history",
                return_messages=True,
                output_key="answer",
            )
            return ConversationalRetrievalChain.from_llm(
                llm=llm,
                chain_type="stuff",
                retriever=retriever,
                return_source_documents=True,
                memory=memory,
                combine_docs_chain_kwargs={"prompt": prompt},
            )

        return RetrievalQA.from_chain_type(
            llm=llm,
            chain_type="stuff",
            retriever=retriever,
            return_source_documents=True,
            chain_type_kwargs={"prompt": prompt},
        )

    # Loading the model
    def load_llm(self):
        """Load and return the chat model via ``ChatModelLoader``."""
        chat_model_loader = ChatModelLoader(self.config)
        return chat_model_loader.load_chat_model()

    # QA Model Function
    def qa_bot(self):
        """Load the database and model, then return a ready QA chain.

        The loaded model is also kept on ``self.llm``.
        """
        db = self.vector_db.load_database()
        self.llm = self.load_llm()
        qa_prompt = self.set_custom_prompt()
        return self.retrieval_qa_chain(self.llm, qa_prompt, db)

    # output function
    def final_result(self, query):
        """Run ``query`` through a freshly built QA chain and return its output.

        Args:
            query: The user's question text.

        Returns:
            Whatever the chain returns when called with the question.
        """
        qa_chain = self.qa_bot()
        # RetrievalQA reads its input from "query", while
        # ConversationalRetrievalChain expects "question"; pick the key that
        # matches the chain type chosen in ``retrieval_qa_chain``.
        if self.config["llm_params"]["use_history"]:
            input_key = "question"
        else:
            input_key = "query"
        return qa_chain({input_key: query})