import os

from langchain import LLMChain, PromptTemplate
from langchain.chains import ConversationalRetrievalChain
from langchain.chains.base import Chain
from langchain.memory import ConversationBufferMemory

from app_modules.llm_inference import LLMInference


def get_llama_2_prompt_template() -> str:
    """Return a Llama-2-style chat prompt template.

    Follows Meta's Llama-2 instruction format
    ``[INST] <<SYS>> ... <</SYS>> ... [/INST]`` and exposes the
    ``{chat_history}`` and ``{question}`` placeholders for LangChain
    to fill in.
    """
    B_INST, E_INST = "[INST]", "[/INST]"
    # BUGFIX: the system-prompt delimiters must be "<<SYS>>" / "<</SYS>>";
    # the previous "<>" markers dropped the SYS tag and are not recognized
    # by Llama-2 chat models.
    B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"

    instruction = "Chat History:\n\n{chat_history} \n\nUser: {question}"
    system_prompt = (
        "You are a helpful assistant, you always only answer for the assistant "
        "then you stop. read the chat history to get context"
    )

    SYSTEM_PROMPT = B_SYS + system_prompt + E_SYS
    return B_INST + SYSTEM_PROMPT + instruction + E_INST


class ChatChain(LLMInference):
    """LLM inference chain for multi-turn chat with buffered history memory."""

    def __init__(self, llm_loader):
        super().__init__(llm_loader)

    def create_chain(self) -> Chain:
        """Build an ``LLMChain`` wired with conversation-buffer memory.

        Uses the Llama-2 prompt template when the environment variable
        ``USE_LLAMA_2_PROMPT_TEMPLATE`` equals ``"true"``; otherwise a
        plain chatbot template is used. Both templates consume the
        ``chat_history`` and ``question`` variables, with ``chat_history``
        supplied by the attached memory.
        """
        template = (
            get_llama_2_prompt_template()
            if os.environ.get("USE_LLAMA_2_PROMPT_TEMPLATE") == "true"
            else """You are a chatbot having a conversation with a human.
{chat_history}
Human: {question}
Chatbot:"""
        )

        print(f"template: {template}")

        prompt = PromptTemplate(
            input_variables=["chat_history", "question"], template=template
        )
        # memory_key matches the {chat_history} placeholder in both templates.
        memory = ConversationBufferMemory(memory_key="chat_history")

        return LLMChain(
            llm=self.llm_loader.llm,
            prompt=prompt,
            verbose=True,
            memory=memory,
        )