from langchain import PromptTemplate
from langchain.embeddings import HuggingFaceEmbeddings
from langchain_community.chat_models import ChatOpenAI
from langchain_community.embeddings import OpenAIEmbeddings
from langchain.vectorstores import FAISS
from langchain.chains import RetrievalQA, ConversationalRetrievalChain
from langchain.llms import CTransformers
from langchain.memory import ConversationBufferWindowMemory
from langchain.chains.conversational_retrieval.prompts import QA_PROMPT
import os
from modules.constants import *
from modules.helpers import get_prompt
from modules.chat_model_loader import ChatModelLoader
from modules.vector_db import VectorDB, VectorDBScore


class LLMTutor:
    """Assemble a retrieval-augmented QA chain from a config dict.

    The config is expected to contain at least the ``"embedding_options"``
    and ``"llm_params"`` sections read by the methods below.
    """

    def __init__(self, config, logger=None):
        """Set up the vector database wrapper.

        Args:
            config: Configuration dict; ``config["embedding_options"]`` is
                read here and in ``retrieval_qa_chain``.
            logger: Optional logger forwarded to ``VectorDB``.
        """
        self.config = config
        self.vector_db = VectorDB(config, logger=logger)
        # Only (re)build and persist the database when asked to; otherwise
        # qa_bot() loads whatever was saved previously.
        if self.config["embedding_options"]["embedd_files"]:
            self.vector_db.create_database()
            self.vector_db.save_database()

    def set_custom_prompt(self):
        """Return the QA prompt template produced by ``get_prompt(config)``."""
        prompt = get_prompt(self.config)
        return prompt

    # Retrieval QA Chain
    def retrieval_qa_chain(self, llm, prompt, db):
        """Build the QA chain around ``llm`` using ``db`` as the retriever source.

        Args:
            llm: The chat model returned by ``load_llm``.
            prompt: Prompt passed to the "stuff" documents chain.
            db: The object returned by ``VectorDB.load_database()``.

        Returns:
            A ``ConversationalRetrievalChain`` with windowed memory when
            ``config["llm_params"]["use_history"]`` is truthy, otherwise a
            ``RetrievalQA`` chain. Both return source documents.

        Raises:
            ValueError: If ``embedding_options["db_option"]`` is not one of
                "FAISS", "Chroma" or "RAGatouille".
        """
        embedding_options = self.config["embedding_options"]
        db_option = embedding_options["db_option"]

        if db_option in ["FAISS", "Chroma"]:
            # Score-thresholded similarity search: documents scoring below
            # score_threshold are dropped, at most search_top_k are kept.
            retriever = VectorDBScore(
                vectorstore=db,
                search_type="similarity_score_threshold",
                search_kwargs={
                    "score_threshold": embedding_options["score_threshold"],
                    "k": embedding_options["search_top_k"],
                },
            )
        elif db_option == "RAGatouille":
            retriever = db.as_langchain_retriever(
                k=embedding_options["search_top_k"]
            )
        else:
            # Without this branch `retriever` would be unbound below and the
            # failure would surface as an opaque UnboundLocalError.
            raise ValueError(f"Unsupported db_option: {db_option!r}")

        if self.config["llm_params"]["use_history"]:
            # Keep only the last `memory_window` exchanges; output_key tells
            # the memory which chain output ("answer") to store.
            memory = ConversationBufferWindowMemory(
                k=self.config["llm_params"]["memory_window"],
                memory_key="chat_history",
                return_messages=True,
                output_key="answer",
            )
            qa_chain = ConversationalRetrievalChain.from_llm(
                llm=llm,
                chain_type="stuff",
                retriever=retriever,
                return_source_documents=True,
                memory=memory,
                combine_docs_chain_kwargs={"prompt": prompt},
            )
        else:
            qa_chain = RetrievalQA.from_chain_type(
                llm=llm,
                chain_type="stuff",
                retriever=retriever,
                return_source_documents=True,
                chain_type_kwargs={"prompt": prompt},
            )
        return qa_chain

    # Loading the model
    def load_llm(self):
        """Load the chat model via ``ChatModelLoader(config)``."""
        chat_model_loader = ChatModelLoader(self.config)
        llm = chat_model_loader.load_chat_model()
        return llm

    # QA Model Function
    def qa_bot(self):
        """Load the vector DB and model, then build and return the QA chain.

        Side effect: stores the loaded model on ``self.llm``.
        """
        db = self.vector_db.load_database()
        self.llm = self.load_llm()
        qa_prompt = self.set_custom_prompt()
        qa = self.retrieval_qa_chain(self.llm, qa_prompt, db)
        return qa


# output function
def final_result(query, config=None, logger=None):
    """Build a QA chain from ``config`` and run ``query`` through it.

    Args:
        query: The user's question.
        config: Configuration dict used to construct an ``LLMTutor``.
        logger: Optional logger forwarded to ``LLMTutor``.

    Returns:
        The chain's output dict.

    Raises:
        ValueError: If ``config`` is not supplied.
    """
    if config is None:
        raise ValueError("final_result requires a config to build the QA chain")
    qa_result = LLMTutor(config, logger=logger).qa_bot()
    # RetrievalQA reads its input from "query", whereas
    # ConversationalRetrievalChain (built when use_history is on) reads it
    # from "question".
    input_key = "question" if config["llm_params"]["use_history"] else "query"
    response = qa_result({input_key: query})
    return response