import json
import os
import re

from langchain.chains import ConversationChain, LLMChain
from langchain.chains.base import Chain
from langchain.globals import get_debug
from langchain.prompts import PromptTemplate

from app_modules.llm_inference import LLMInference
from app_modules.utils import CustomizedConversationSummaryBufferMemory

chat_history_enabled = os.getenv("CHAT_HISTORY_ENABLED", "false").lower() == "true"

B_INST, E_INST = "[INST]", "[/INST]"


def get_system_prompt_and_user_message(orca=False):
    # system_prompt = "You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information."
    system_prompt = (
        "You are Orca, an AI language model created by Microsoft. You are a cautious assistant. You carefully follow instructions. You are helpful and harmless and you follow ethical guidelines and promote positive behavior."
        if orca
        else "You are a chatbot having a conversation with a human."
    )

    user_message = "{input}"

    if chat_history_enabled:
        user_message = "Chat History:\n\n{history} \n\n" + user_message
        system_prompt += " Read the chat history to get context."

    return system_prompt, user_message


def create_llama_2_prompt_template():
    # Llama 2's chat format wraps the system prompt in <<SYS>> / <</SYS>> markers
    # inside the [INST] ... [/INST] block.
    B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"

    system_prompt, user_message = get_system_prompt_and_user_message()
    SYSTEM_PROMPT = B_SYS + system_prompt + E_SYS
    prompt_template = B_INST + SYSTEM_PROMPT + user_message + E_INST
    return prompt_template


def create_llama_3_prompt_template():
    system_prompt, user_message = get_system_prompt_and_user_message()
    prompt_template = f"""<|begin_of_text|><|start_header_id|>system<|end_header_id|>

{system_prompt}<|eot_id|><|start_header_id|>user<|end_header_id|>

{user_message}<|eot_id|><|start_header_id|>assistant<|end_header_id|>

"""
    return prompt_template


def create_phi_3_prompt_template():
    system_prompt, user_message = get_system_prompt_and_user_message()
    prompt_template = f"""<|system|>
{system_prompt}<|end|>
<|user|>
{user_message}<|end|>
<|assistant|>
"""
    return prompt_template


def create_orca_2_prompt_template():
    # Orca 2 uses the ChatML format; request the Orca-specific system prompt.
    system_prompt, user_message = get_system_prompt_and_user_message(orca=True)
    prompt_template = f"<|im_start|>system\n{system_prompt}<|im_end|>\n<|im_start|>user\n{user_message}<|im_end|>\n<|im_start|>assistant"
    return prompt_template


def create_mistral_prompt_template():
    system_prompt, user_message = get_system_prompt_and_user_message()
    prompt_template = B_INST + system_prompt + "\n\n" + user_message + E_INST
    return prompt_template


def create_gemma_prompt_template():
    # Gemma's instruction format delimits each turn with
    # <start_of_turn> / <end_of_turn> markers.
    return "<start_of_turn>user\n{input}<end_of_turn>\n<start_of_turn>model\n"
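# For reference, with CHAT_HISTORY_ENABLED unset, create_llama_3_prompt_template()
# renders to the string below; PromptTemplate later fills the {input} placeholder:
#
#   <|begin_of_text|><|start_header_id|>system<|end_header_id|>
#
#   You are a chatbot having a conversation with a human.<|eot_id|><|start_header_id|>user<|end_header_id|>
#
#   {input}<|eot_id|><|start_header_id|>assistant<|end_header_id|>
#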
def create_prompt_template(model_name):
    print(f"creating prompt template for model: {model_name}")
    if re.search(r"llama-?2", model_name, re.IGNORECASE):
        return create_llama_2_prompt_template()
    elif re.search(r"llama-?3", model_name, re.IGNORECASE):
        return create_llama_3_prompt_template()
    elif re.search(r"phi-?3", model_name, re.IGNORECASE):
        return create_phi_3_prompt_template()
    elif model_name.lower().startswith("orca"):
        return create_orca_2_prompt_template()
    elif model_name.lower().startswith("mistral"):
        return create_mistral_prompt_template()
    elif model_name.lower().startswith("gemma"):
        return create_gemma_prompt_template()

    # Fall back to a plain conversational template for unrecognized models.
    return (
        """You are a chatbot having a conversation with a human.

{history}
Human: {input}
Chatbot:"""
        if chat_history_enabled
        else """You are a chatbot having a conversation with a human.

Human: {input}
Chatbot:"""
    )


class ChatChain(LLMInference):
    def __init__(self, llm_loader):
        super().__init__(llm_loader)

    def create_chain(self) -> Chain:
        template = create_prompt_template(self.llm_loader.model_name)
        print(f"template: {template}")

        if chat_history_enabled:
            prompt = PromptTemplate(
                input_variables=["history", "input"], template=template
            )
            memory = CustomizedConversationSummaryBufferMemory(
                llm=self.llm_loader.llm, max_token_limit=1024, return_messages=False
            )
            llm_chain = ConversationChain(
                llm=self.llm_loader.llm,
                prompt=prompt,
                verbose=False,
                memory=memory,
            )
        else:
            prompt = PromptTemplate(input_variables=["input"], template=template)
            llm_chain = LLMChain(llm=self.llm_loader.llm, prompt=prompt)

        return llm_chain

    def _process_inputs(self, inputs):
        if not isinstance(inputs, list):
            inputs = {"input": inputs["question"]}
        elif self.llm_loader.llm_model_type == "huggingface":
            # Hugging Face chat models expect a list of role/content messages
            # per example rather than a single pre-formatted string.
            inputs = [
                [
                    {
                        "role": "system",
                        "content": self.get_system_message(i),
                    },
                    {
                        "role": "user",
                        "content": self.get_user_message(i),
                    },
                ]
                for i in inputs
            ]
        else:
            inputs = [{"input": i["question"]} for i in inputs]

        if get_debug():
            print("_process_inputs:", json.dumps(inputs, indent=4))

        return inputs

    def get_system_message(self, input) -> str:
        return get_system_prompt_and_user_message()[0]

    def get_user_message(self, input) -> str:
        return input["question"]
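

# Minimal smoke-test sketch: render the prompt template for a handful of model
# names to exercise the dispatch in create_prompt_template(). The model names
# below are illustrative placeholders, not values the module depends on.
if __name__ == "__main__":
    for name in [
        "Llama-2-7b-chat-hf",
        "Meta-Llama-3-8B-Instruct",
        "Phi-3-mini-4k-instruct",
        "orca-2-7b",
        "mistral-7b-instruct",
        "gemma-7b-it",
        "unknown-model",
    ]:
        print(f"=== {name} ===")
        print(create_prompt_template(name))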