Commit e165ea5 by XThomasBU ("updates")
Parent: 9c5bb39

Files changed:
- code/main.py (+3 -1)
- code/modules/chat/langchain/langchain_rag.py (+78 -4)
- code/modules/chat/langchain/utils.py (+134 -16)
- code/modules/chat/llm_tutor.py (+3 -2)
code/main.py
CHANGED
@@ -288,11 +288,13 @@ class Chatbot:
             }
         }
 
+        stream = False
+
         if stream:
            res = chain.stream(user_query=user_query_dict, config=chain_config)
            res = await self.stream_response(res)
         else:
-            res = chain.invoke(user_query=user_query_dict, config=chain_config)
+            res = await chain.invoke(user_query=user_query_dict, config=chain_config)
 
         answer = res.get("answer", res.get("result"))
 
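The non-streaming branch now awaits chain.invoke, so the chain must expose invoke as a coroutine (the langchain_rag.py change below makes it async). A minimal runnable sketch of this calling pattern, with EchoChain as a hypothetical stand-in for the real RAG chain:

import asyncio


class EchoChain:
    # Hypothetical stand-in for the RAG chain: `invoke` is a coroutine,
    # matching the `res = await chain.invoke(...)` call added in this commit.
    async def invoke(self, user_query, config):
        return {"answer": f"echo: {user_query['input']}"}


async def main():
    chain = EchoChain()
    stream = False  # the commit hard-codes streaming off
    if stream:
        raise NotImplementedError("streaming path elided in this sketch")
    res = await chain.invoke(user_query={"input": "hi"}, config={})
    answer = res.get("answer", res.get("result"))
    print(answer)  # -> echo: hi


asyncio.run(main())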
code/modules/chat/langchain/langchain_rag.py
CHANGED
@@ -3,10 +3,84 @@ from langchain_core.prompts import ChatPromptTemplate
 from modules.chat.langchain.utils import *
 from langchain.memory import ChatMessageHistory
 from modules.chat.base import BaseRAG
+from langchain_core.prompts import PromptTemplate
+from langchain.memory import (
+    ConversationBufferWindowMemory,
+    ConversationSummaryBufferMemory,
+)
 
 
-class Langchain_RAG(BaseRAG):
-    def __init__(self, llm, memory, retriever, qa_prompt: str, rephrase_prompt: str):
+class Langchain_RAG_V1(BaseRAG):
+
+    def __init__(
+        self, llm, memory, retriever, qa_prompt: str, rephrase_prompt: str, config: dict
+    ):
+        """
+        Initialize the Langchain_RAG class.
+
+        Args:
+            llm (LanguageModelLike): The language model instance.
+            memory (BaseChatMessageHistory): The chat message history instance.
+            retriever (BaseRetriever): The retriever instance.
+            qa_prompt (str): The QA prompt string.
+            rephrase_prompt (str): The rephrase prompt string.
+        """
+        self.llm = llm
+        self.config = config
+        # self.memory = self.add_history_from_list(memory)
+        self.memory = ConversationBufferWindowMemory(
+            k=self.config["llm_params"]["memory_window"],
+            memory_key="chat_history",
+            return_messages=True,
+            output_key="answer",
+            max_token_limit=128,
+        )
+        self.retriever = retriever
+        self.qa_prompt = qa_prompt
+        self.rephrase_prompt = rephrase_prompt
+        self.store = {}
+
+        self.qa_prompt = PromptTemplate(
+            template=self.qa_prompt,
+            input_variables=["context", "chat_history", "input"],
+        )
+
+        self.rag_chain = CustomConversationalRetrievalChain.from_llm(
+            llm=llm,
+            chain_type="stuff",
+            retriever=retriever,
+            return_source_documents=True,
+            memory=self.memory,
+            combine_docs_chain_kwargs={"prompt": self.qa_prompt},
+            response_if_no_docs_found="No context found",
+        )
+
+    def add_history_from_list(self, history_list):
+        """
+        TODO: Add messages from a list to the chat history.
+        """
+        history = []
+
+        return history
+
+    async def invoke(self, user_query, config):
+        """
+        Invoke the chain.
+
+        Args:
+            kwargs: The input variables.
+
+        Returns:
+            dict: The output variables.
+        """
+        res = await self.rag_chain.acall(user_query["input"])
+        return res
+
+
+class Langchain_RAG_V2(BaseRAG):
+    def __init__(
+        self, llm, memory, retriever, qa_prompt: str, rephrase_prompt: str, config: dict
+    ):
         """
         Initialize the Langchain_RAG class.
 
@@ -118,7 +192,7 @@ class Langchain_RAG(BaseRAG):
         ) # add previous messages to the store. Note: the store is in-memory.
         return self.store[(user_id, conversation_id)]
 
-    def invoke(self, user_query, config):
+    async def invoke(self, user_query, config):
         """
         Invoke the chain.
 
@@ -128,7 +202,7 @@ class Langchain_RAG(BaseRAG):
         Returns:
             dict: The output variables.
         """
-        res = self.rag_chain.invoke(user_query, config)
+        res = await self.rag_chain.ainvoke(user_query, config)
         res["rephrase_prompt"] = self.rephrase_prompt
         res["qa_prompt"] = self.qa_prompt
         return res
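A minimal sketch of the windowed memory Langchain_RAG_V1 configures above, runnable with langchain installed. The k=3 is an example value standing in for config["llm_params"]["memory_window"]; max_token_limit (a ConversationSummaryBufferMemory field) is omitted here since the window memory does not use it:

from langchain.memory import ConversationBufferWindowMemory

memory = ConversationBufferWindowMemory(
    k=3,                        # keep only the last k exchanges in the window
    memory_key="chat_history",  # variable name the chain reads history from
    return_messages=True,       # return message objects rather than one string
    output_key="answer",        # which chain output to store as the AI turn
)
memory.save_context(
    {"input": "What is backpropagation?"},
    {"answer": "A way to compute gradients layer by layer."},
)
print(memory.load_memory_variables({})["chat_history"])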
code/modules/chat/langchain/utils.py
CHANGED
@@ -35,6 +35,133 @@ from langchain_core.runnables.config import RunnableConfig
 from langchain_core.messages import BaseMessage
 
 
+from langchain_core.output_parsers import StrOutputParser
+from langchain_core.prompts import ChatPromptTemplate
+from langchain_community.chat_models import ChatOpenAI
+
+from langchain.chains import RetrievalQA, ConversationalRetrievalChain
+from langchain_core.callbacks.manager import AsyncCallbackManagerForChainRun
+
+from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
+from langchain_core.callbacks.manager import AsyncCallbackManagerForChainRun
+import inspect
+from langchain.chains.conversational_retrieval.base import _get_chat_history
+from langchain_core.messages import BaseMessage
+
+
+class CustomConversationalRetrievalChain(ConversationalRetrievalChain):
+
+    def _get_chat_history(self, chat_history: List[CHAT_TURN_TYPE]) -> str:
+        _ROLE_MAP = {"human": "Student: ", "ai": "AI Tutor: "}
+        buffer = ""
+        for dialogue_turn in chat_history:
+            if isinstance(dialogue_turn, BaseMessage):
+                role_prefix = _ROLE_MAP.get(
+                    dialogue_turn.type, f"{dialogue_turn.type}: "
+                )
+                buffer += f"\n{role_prefix}{dialogue_turn.content}"
+            elif isinstance(dialogue_turn, tuple):
+                human = "Student: " + dialogue_turn[0]
+                ai = "AI Tutor: " + dialogue_turn[1]
+                buffer += "\n" + "\n".join([human, ai])
+            else:
+                raise ValueError(
+                    f"Unsupported chat history format: {type(dialogue_turn)}."
+                    f" Full chat history: {chat_history} "
+                )
+        return buffer
+
+    async def _acall(
+        self,
+        inputs: Dict[str, Any],
+        run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
+    ) -> Dict[str, Any]:
+        _run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
+        question = inputs["question"]
+        get_chat_history = self._get_chat_history
+        chat_history_str = get_chat_history(inputs["chat_history"])
+        if chat_history_str:
+            # callbacks = _run_manager.get_child()
+            # new_question = await self.question_generator.arun(
+            #     question=question, chat_history=chat_history_str, callbacks=callbacks
+            # )
+            system = (
+                "You are someone that rephrases statements. Rephrase the student's question to add context from their chat history if relevant, ensuring it remains from the student's point of view. "
+                "Incorporate relevant details from the chat history to make the question clearer and more specific. "
+                "Do not change the meaning of the original statement, and maintain the student's tone and perspective. "
+                "If the question is conversational and doesn't require context, do not rephrase it. "
+                "Example: If the student previously asked about backpropagation in the context of deep learning and now asks 'what is it', rephrase to 'What is backpropagation?'. "
+                "Example: Do not rephrase if the user is asking something specific like 'cool, suggest a project with transformers to use as my final project'. "
+                "Chat history: \n{chat_history_str}\n"
+                "Rephrase the following question only if necessary: '{input}'"
+            )
+
+            prompt = ChatPromptTemplate.from_messages(
+                [
+                    ("system", system),
+                    ("human", "{input}, {chat_history_str}"),
+                ]
+            )
+            llm = ChatOpenAI(model="gpt-3.5-turbo-0125", temperature=0)
+            step_back = prompt | llm | StrOutputParser()
+            new_question = step_back.invoke(
+                {"input": question, "chat_history_str": chat_history_str}
+            )
+        else:
+            new_question = question
+        accepts_run_manager = (
+            "run_manager" in inspect.signature(self._aget_docs).parameters
+        )
+        if accepts_run_manager:
+            docs = await self._aget_docs(new_question, inputs, run_manager=_run_manager)
+        else:
+            docs = await self._aget_docs(new_question, inputs)  # type: ignore[call-arg]
+
+        output: Dict[str, Any] = {}
+        output["original_question"] = question
+        if self.response_if_no_docs_found is not None and len(docs) == 0:
+            output[self.output_key] = self.response_if_no_docs_found
+        else:
+            new_inputs = inputs.copy()
+            if self.rephrase_question:
+                new_inputs["question"] = new_question
+            new_inputs["chat_history"] = chat_history_str
+
+            # Prepare the final prompt with metadata
+            context = "\n\n".join(
+                [
+                    f"Context {idx+1}: \n(Document content: {doc.page_content}\nMetadata: (source_file: {doc.metadata['source'] if 'source' in doc.metadata else 'unknown'}))"
+                    for idx, doc in enumerate(docs)
+                ]
+            )
+            final_prompt = (
+                "You are an AI Tutor for the course DS598, taught by Prof. Thomas Gardos. Answer the user's question using the provided context. Only use the context if it is relevant. The context is ordered by relevance. "
+                "If you don't know the answer, do your best without making things up. Keep the conversation flowing naturally. "
+                "Use chat history and context as guides but avoid repeating past responses. Provide links from the source_file metadata. Use the source context that is most relevant. "
+                "Speak in a friendly and engaging manner, like talking to a friend. Avoid sounding repetitive or robotic.\n\n"
+                f"Chat History:\n{chat_history_str}\n\n"
+                f"Context:\n{context}\n\n"
+                "Answer the student's question below in a friendly, concise, and engaging manner. Use the context and history only if relevant, otherwise, engage in a free-flowing conversation.\n"
+                f"Student: {new_question}\n"
+                "AI Tutor:"
+            )
+
+            new_inputs["input"] = final_prompt
+            # new_inputs["question"] = final_prompt
+            # output["final_prompt"] = final_prompt
+
+            answer = await self.combine_docs_chain.arun(
+                input_documents=docs, callbacks=_run_manager.get_child(), **new_inputs
+            )
+            output[self.output_key] = answer
+
+        if self.return_source_documents:
+            output["source_documents"] = docs
+            output["rephrased_question"] = new_question
+            output["context"] = output["source_documents"]
+        return output
+
+
 class CustomRunnableWithHistory(RunnableWithMessageHistory):
 
     def _get_chat_history(self, chat_history: List[CHAT_TURN_TYPE]) -> str:
@@ -69,6 +196,10 @@ class CustomRunnableWithHistory(RunnableWithMessageHistory):
             List[BaseMessage]: The last k conversations.
         """
         hist: BaseChatMessageHistory = config["configurable"]["message_history"]
+
+        print("\n\n\n")
+        print("Hist: ", hist)
+        print("\n\n\n")
         messages = hist.messages.copy()
 
         if not self.history_messages_key:
@@ -83,6 +214,9 @@ class CustomRunnableWithHistory(RunnableWithMessageHistory):
 
         messages = self._get_chat_history(messages)
 
+        print("\n\n\n")
+        print("Messages: ", messages)
+        print("\n\n\n")
         return messages
 
 
@@ -103,22 +237,6 @@ class InMemoryHistory(BaseChatMessageHistory, BaseModel):
         """Return the number of messages."""
         return len(self.messages)
 
-    def get_last_n_conversations(self, n: int) -> "InMemoryHistory":
-        """Return a new InMemoryHistory object with the last n conversations from the message history.
-
-        Args:
-            n (int): The number of last conversations to return. If 0, return an empty history.
-
-        Returns:
-            InMemoryHistory: A new InMemoryHistory object containing the last n conversations.
-        """
-        if n == 0:
-            return InMemoryHistory()
-        # Each conversation consists of a pair of messages (human + AI)
-        num_messages = n * 2
-        last_messages = self.messages[-num_messages:]
-        return InMemoryHistory(messages=last_messages)
-
 
 def create_history_aware_retriever(
     llm: LanguageModelLike,
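The core of the new _acall is the inline rephrase step, prompt | llm | StrOutputParser(). A standalone sketch of just that step, assuming OPENAI_API_KEY is set in the environment; the system prompt here is abbreviated from the one in the diff:

from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_community.chat_models import ChatOpenAI

prompt = ChatPromptTemplate.from_messages(
    [
        ("system",
         "Rephrase the student's question using the chat history only if it "
         "adds needed context; keep the student's point of view.\n"
         "Chat history:\n{chat_history_str}"),
        ("human", "{input}"),
    ]
)
llm = ChatOpenAI(model="gpt-3.5-turbo-0125", temperature=0)  # model hard-coded in the diff
step_back = prompt | llm | StrOutputParser()

new_question = step_back.invoke(
    {"input": "what is it?",
     "chat_history_str": "Student: Tell me about backpropagation.\nAI Tutor: ..."}
)
print(new_question)  # e.g. "What is backpropagation?"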
code/modules/chat/llm_tutor.py
CHANGED
@@ -2,7 +2,7 @@ from modules.chat.helpers import get_prompt
 from modules.chat.chat_model_loader import ChatModelLoader
 from modules.vectorstore.store_manager import VectorStoreManager
 from modules.retriever.retriever import Retriever
-from modules.chat.langchain.langchain_rag import Langchain_RAG
+from modules.chat.langchain.langchain_rag import Langchain_RAG_V1, Langchain_RAG_V2
 
 
 class LLMTutor:
@@ -103,12 +103,13 @@ class LLMTutor:
         retriever = Retriever(self.config)._return_retriever(db)
 
         if self.config["llm_params"]["llm_arch"] == "langchain":
-            self.qa_chain = Langchain_RAG(
+            self.qa_chain = Langchain_RAG_V2(
                 llm=llm,
                 memory=memory,
                 retriever=retriever,
                 qa_prompt=qa_prompt,
                 rephrase_prompt=rephrase_prompt,
+                config=self.config,
             )
         else:
             raise ValueError(