from langchain_core.prompts import ChatPromptTemplate
from modules.chat.langchain.utils import *
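# The wildcard import above is expected to provide the LangChain helpers used
# below: create_history_aware_retriever, create_stuff_documents_chain,
# create_retrieval_chain, ConfigurableFieldSpec, BaseChatMessageHistory,
# InMemoryHistory, and the project's CustomRunnableWithHistory wrapper.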
from langchain.memory import ChatMessageHistory
from modules.chat.base import BaseRAG
class Langchain_RAG(BaseRAG):
def __init__(self, llm, memory, retriever, qa_prompt: str, rephrase_prompt: str):
"""
Initialize the Langchain_RAG class.
Args:
llm (LanguageModelLike): The language model instance.
            memory (list): Previous conversation turns as (user_message, ai_message) pairs.
retriever (BaseRetriever): The retriever instance.
qa_prompt (str): The QA prompt string.
rephrase_prompt (str): The rephrase prompt string.
"""
self.llm = llm
self.memory = self.add_history_from_list(memory)
self.retriever = retriever
self.qa_prompt = qa_prompt
self.rephrase_prompt = rephrase_prompt
self.store = {}
        # Contextualize question prompt. create_history_aware_retriever requires
        # an {input} variable in this template, so the fallback exposes the
        # chat history and the latest question explicitly.
        contextualize_q_system_prompt = rephrase_prompt or (
            "Given a chat history and the latest user question "
            "which might reference context in the chat history, "
            "formulate a standalone question which can be understood "
            "without the chat history. Do NOT answer the question, just "
            "reformulate it if needed and otherwise return it as is."
            "\n\n"
            "Chat history:\n{chat_history}\n\n"
            "Question: {input}"
        )
self.contextualize_q_prompt = ChatPromptTemplate.from_template(
contextualize_q_system_prompt
)
# History-aware retriever
self.history_aware_retriever = create_history_aware_retriever(
self.llm, self.retriever, self.contextualize_q_prompt
)
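        # When chat_history is empty, the history-aware retriever forwards the
        # raw input straight to the retriever; otherwise it first asks the LLM
        # to rewrite the question as a standalone query before retrieving.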
        # Answer question prompt. {context} is filled with the retrieved
        # documents; {input} carries the user's question to the model.
        qa_system_prompt = qa_prompt or (
            "You are an assistant for question-answering tasks. Use "
            "the following pieces of retrieved context to answer the "
            "question. If you don't know the answer, just say that you "
            "don't know. Use three sentences maximum and keep the answer "
            "concise."
            "\n\n"
            "{context}"
            "\n\n"
            "Question: {input}"
        )
self.qa_prompt_template = ChatPromptTemplate.from_template(qa_system_prompt)
# Question-answer chain
self.question_answer_chain = create_stuff_documents_chain(
self.llm, self.qa_prompt_template
)
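        # create_stuff_documents_chain formats the retrieved documents into the
        # prompt's {context} variable and sends the filled prompt to the LLM.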
# Final retrieval chain
self.rag_chain = create_retrieval_chain(
self.history_aware_retriever, self.question_answer_chain
)
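        # The retrieval chain returns a dict with the original "input", the
        # retrieved "context" documents, and the generated "answer". Wrapping
        # it below injects per-session chat history under "chat_history",
        # keyed by the configurable fields declared in history_factory_config.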
self.rag_chain = CustomRunnableWithHistory(
self.rag_chain,
get_session_history=self.get_session_history,
input_messages_key="input",
history_messages_key="chat_history",
output_messages_key="answer",
history_factory_config=[
ConfigurableFieldSpec(
id="user_id",
annotation=str,
name="User ID",
description="Unique identifier for the user.",
default="",
is_shared=True,
),
ConfigurableFieldSpec(
id="conversation_id",
annotation=str,
name="Conversation ID",
description="Unique identifier for the conversation.",
default="",
is_shared=True,
),
ConfigurableFieldSpec(
id="memory_window",
annotation=int,
name="Number of Conversations",
description="Number of conversations to consider for context.",
default=1,
is_shared=True,
),
],
)
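        # get_session_history below is called with the values supplied for
        # these fields under config["configurable"] on each invoke/stream call.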
def get_session_history(
self, user_id: str, conversation_id: str, memory_window: int
) -> BaseChatMessageHistory:
"""
Get the session history for a user and conversation.
Args:
user_id (str): The user identifier.
conversation_id (str): The conversation identifier.
memory_window (int): The number of conversations to consider for context.
Returns:
BaseChatMessageHistory: The chat message history.
"""
if (user_id, conversation_id) not in self.store:
self.store[(user_id, conversation_id)] = InMemoryHistory()
self.store[(user_id, conversation_id)].add_messages(
self.memory.messages
) # add previous messages to the store. Note: the store is in-memory.
return self.store[(user_id, conversation_id)]
def invoke(self, user_query, config):
"""
Invoke the chain.
        Args:
            user_query (dict): The input variables, e.g. {"input": "<question>"}.
            config (dict): Runnable config carrying user_id, conversation_id,
                and memory_window under the "configurable" key.
        Returns:
            dict: The output variables, including the retrieved "context" and
                the generated "answer".
"""
res = self.rag_chain.invoke(user_query, config)
res["rephrase_prompt"] = self.rephrase_prompt
res["qa_prompt"] = self.qa_prompt
return res
    def stream(self, user_query, config):
        """Stream the chain's output as an iterator of incremental chunks."""
        res = self.rag_chain.stream(user_query, config)
        return res
def add_history_from_list(self, history_list):
"""
Add messages from a list to the chat history.
Args:
            history_list (list): List of (user_message, ai_message) pairs.
        Returns:
            ChatMessageHistory: The populated chat message history.
        """
history = ChatMessageHistory()
        for user_message, ai_message in history_list:
            history.add_user_message(user_message)
            history.add_ai_message(ai_message)
return history
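# Illustrative usage sketch (not part of the original module). `llm` and
# `retriever` stand in for whatever model and retriever the caller has already
# built, and the prompt strings only show the placeholders the chain expects:
#
#   rag = Langchain_RAG(
#       llm=llm,
#       memory=[("Hi", "Hello! How can I help?")],
#       retriever=retriever,
#       qa_prompt="Answer from the context.\n\n{context}\n\nQuestion: {input}",
#       rephrase_prompt=(
#           "Rewrite the question so it stands alone.\n\n"
#           "History:\n{chat_history}\n\nQuestion: {input}"
#       ),
#   )
#   config = {"configurable": {"user_id": "u1",
#                              "conversation_id": "c1",
#                              "memory_window": 3}}
#   print(rag.invoke({"input": "What does lecture 2 cover?"}, config)["answer"])
#   for chunk in rag.stream({"input": "Summarize it."}, config):
#       ...  # each chunk is a partial output update from the streaming chain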