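"""LLM tutor chains for the DS598 AI Tutor.

Defines CustomConversationalRetrievalChain, a ConversationalRetrievalChain
variant that injects document metadata into the final prompt, and LLMTutor,
which wires together the chat model, vector database, and retrieval chain.
"""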
import inspect
from typing import Any, Dict, Optional

from langchain.chains import ConversationalRetrievalChain, RetrievalQA
from langchain.chains.conversational_retrieval.base import _get_chat_history
from langchain.chains.conversational_retrieval.prompts import QA_PROMPT
from langchain.memory import ConversationSummaryBufferMemory
from langchain_core.callbacks.manager import AsyncCallbackManagerForChainRun

from modules.chat_model_loader import ChatModelLoader
from modules.constants import *
from modules.helpers import get_prompt
from modules.vector_db import VectorDB, VectorDBScore


class CustomConversationalRetrievalChain(ConversationalRetrievalChain):
    async def _acall(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
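        """Async run of the chain with a metadata-aware tutoring prompt.

        Follows the parent _acall flow (condense the question against chat
        history, retrieve documents), then composes a final prompt that
        includes each document's metadata and exposes it in the output as
        "final_prompt" alongside the answer.
        """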
        _run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
        question = inputs["question"]
        get_chat_history = self.get_chat_history or _get_chat_history
        chat_history_str = get_chat_history(inputs["chat_history"])
        print(f"chat_history_str: {chat_history_str}")
        if chat_history_str:
            callbacks = _run_manager.get_child()
            new_question = await self.question_generator.arun(
                question=question, chat_history=chat_history_str, callbacks=callbacks
            )
        else:
            new_question = question
        accepts_run_manager = (
            "run_manager" in inspect.signature(self._aget_docs).parameters
        )
        if accepts_run_manager:
            docs = await self._aget_docs(new_question, inputs, run_manager=_run_manager)
        else:
            docs = await self._aget_docs(new_question, inputs)  # type: ignore[call-arg]

        output: Dict[str, Any] = {}
        if self.response_if_no_docs_found is not None and len(docs) == 0:
            output[self.output_key] = self.response_if_no_docs_found
        else:
            new_inputs = inputs.copy()
            if self.rephrase_question:
                new_inputs["question"] = new_question
            new_inputs["chat_history"] = chat_history_str

            # Prepare the final prompt with metadata
            context = "\n\n".join(
                [
                    f"Document content: {doc.page_content}\nMetadata: {doc.metadata}"
                    for doc in docs
                ]
            )
            final_prompt = f"""
                You are an AI Tutor for the course DS598, taught by Prof. Thomas Gardos. Use the following pieces of information to answer the user's question. 
                If you don't know the answer, just say that you don't know—don't try to make up an answer. 
                Use the chat history to answer the question only if it's relevant; otherwise, ignore it. The context for the answer will be under "Document content:". 
                Use the metadata from each document to guide the user to the correct sources. 
                The context is ordered by relevance to the question. Give more weight to the most relevant documents.
                Talk in a friendly and personalized manner, similar to how you would speak to a friend who needs help. Make the conversation engaging and avoid sounding repetitive or robotic.

                Chat History:
                {chat_history_str}

                Context:
                {context}

                Question: {new_question}
                AI Tutor:
                """

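            # Feed the composed prompt under both keys so it reaches the
            # combine_docs chain regardless of which input key it expects.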
            new_inputs["input"] = final_prompt
            new_inputs["question"] = final_prompt
            output["final_prompt"] = final_prompt

            answer = await self.combine_docs_chain.arun(
                input_documents=docs, callbacks=_run_manager.get_child(), **new_inputs
            )
            output[self.output_key] = answer

        if self.return_source_documents:
            output["source_documents"] = docs
        if self.return_generated_question:
            output["generated_question"] = new_question
        return output


class LLMTutor:
    def __init__(self, config, logger=None):
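        """Wire up the tutor: load the chat model and the vector database.

        When config["embedding_options"]["embedd_files"] is truthy, the
        document embeddings are (re)built and persisted before first use.
        """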
        self.config = config
        self.llm = self.load_llm()
        self.vector_db = VectorDB(config, logger=logger)
        if self.config["embedding_options"]["embedd_files"]:
            self.vector_db.create_database()
            self.vector_db.save_database()

    def set_custom_prompt(self):
        """
        Prompt template for QA retrieval for each vectorstore
        """
        prompt = get_prompt(self.config)
        # prompt = QA_PROMPT

        return prompt

    # Retrieval QA Chain
    def retrieval_qa_chain(self, llm, prompt, db):
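        """Build the retrieval chain for the configured vector store.

        FAISS/Chroma go through the custom VectorDBScore retriever;
        RAGatouille is wrapped via its LangChain adapter. With use_history
        enabled, returns a conversational chain backed by summary-buffer
        memory; otherwise, a stateless RetrievalQA chain.
        """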
        if self.config["embedding_options"]["db_option"] in ["FAISS", "Chroma"]:
            retriever = VectorDBScore(
                vectorstore=db,
                # search_type="similarity_score_threshold",
                # search_kwargs={
                #     "score_threshold": self.config["embedding_options"][
                #         "score_threshold"
                #     ],
                #     "k": self.config["embedding_options"]["search_top_k"],
                # },
            )
        elif self.config["embedding_options"]["db_option"] == "RAGatouille":
            retriever = db.as_langchain_retriever(
                k=self.config["embedding_options"]["search_top_k"]
            )
        else:
            # Guard against an undefined retriever further down.
            raise ValueError(
                f"Unsupported db_option: {self.config['embedding_options']['db_option']}"
            )
        if self.config["llm_params"]["use_history"]:
            # ConversationSummaryBufferMemory bounds history by token count
            # (max_token_limit); it has no message-window "k" parameter.
            memory = ConversationSummaryBufferMemory(
                llm=llm,
                memory_key="chat_history",
                return_messages=True,
                output_key="answer",
                max_token_limit=128,
            )
            qa_chain = CustomConversationalRetrievalChain.from_llm(
                llm=llm,
                chain_type="stuff",
                retriever=retriever,
                return_source_documents=True,
                memory=memory,
                combine_docs_chain_kwargs={"prompt": prompt},
            )
        else:
            qa_chain = RetrievalQA.from_chain_type(
                llm=llm,
                chain_type="stuff",
                retriever=retriever,
                return_source_documents=True,
                chain_type_kwargs={"prompt": prompt},
            )
        return qa_chain

    # Loading the model
    def load_llm(self):
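        """Load the chat model specified by the config via ChatModelLoader."""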
        chat_model_loader = ChatModelLoader(self.config)
        llm = chat_model_loader.load_chat_model()
        return llm

    # QA Model Function
    def qa_bot(self):
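        """Load the vector DB and return a ready-to-query retrieval chain."""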
        db = self.vector_db.load_database()
        qa_prompt = self.set_custom_prompt()
        qa = self.retrieval_qa_chain(self.llm, qa_prompt, db)

        return qa

    # Output function
    def final_result(self, query):
        qa_result = self.qa_bot()
        # The conversational chain expects "question"; RetrievalQA expects "query".
        input_key = (
            "question" if self.config["llm_params"]["use_history"] else "query"
        )
        response = qa_result({input_key: query})
        return response
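

# --- Usage sketch (illustrative, not part of the original module) ---
# A minimal driver, assuming a config dict with the keys referenced above.
# The real schema is defined by the caller (and also consumed by VectorDB,
# get_prompt, and ChatModelLoader), so the values here are placeholders.
if __name__ == "__main__":
    config = {
        "embedding_options": {
            "embedd_files": False,  # reuse a previously saved database
            "db_option": "FAISS",
            "search_top_k": 3,
        },
        "llm_params": {
            "use_history": True,
            "memory_window": 3,
        },
    }
    tutor = LLMTutor(config)
    chain = tutor.qa_bot()
    # With use_history=True the chain takes "question" (memory supplies the
    # chat_history); without history, RetrievalQA takes "query" instead.
    result = chain({"question": "What topics does lecture 3 cover?"})
    print(result["answer"])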