XThomasBU commited on
Commit
e165ea5
1 Parent(s): 9c5bb39
code/main.py CHANGED
@@ -288,11 +288,13 @@ class Chatbot:
288
  }
289
  }
290
 
 
 
291
  if stream:
292
  res = chain.stream(user_query=user_query_dict, config=chain_config)
293
  res = await self.stream_response(res)
294
  else:
295
- res = chain.invoke(user_query=user_query_dict, config=chain_config)
296
 
297
  answer = res.get("answer", res.get("result"))
298
 
 
288
  }
289
  }
290
 
291
+ stream = False
292
+
293
  if stream:
294
  res = chain.stream(user_query=user_query_dict, config=chain_config)
295
  res = await self.stream_response(res)
296
  else:
297
+ res = await chain.invoke(user_query=user_query_dict, config=chain_config)
298
 
299
  answer = res.get("answer", res.get("result"))
300
 
code/modules/chat/langchain/langchain_rag.py CHANGED
@@ -3,10 +3,84 @@ from langchain_core.prompts import ChatPromptTemplate
3
  from modules.chat.langchain.utils import *
4
  from langchain.memory import ChatMessageHistory
5
  from modules.chat.base import BaseRAG
 
 
 
 
 
6
 
7
 
8
- class Langchain_RAG(BaseRAG):
9
- def __init__(self, llm, memory, retriever, qa_prompt: str, rephrase_prompt: str):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
  """
11
  Initialize the Langchain_RAG class.
12
 
@@ -118,7 +192,7 @@ class Langchain_RAG(BaseRAG):
118
  ) # add previous messages to the store. Note: the store is in-memory.
119
  return self.store[(user_id, conversation_id)]
120
 
121
- def invoke(self, user_query, config):
122
  """
123
  Invoke the chain.
124
 
@@ -128,7 +202,7 @@ class Langchain_RAG(BaseRAG):
128
  Returns:
129
  dict: The output variables.
130
  """
131
- res = self.rag_chain.invoke(user_query, config)
132
  res["rephrase_prompt"] = self.rephrase_prompt
133
  res["qa_prompt"] = self.qa_prompt
134
  return res
 
3
  from modules.chat.langchain.utils import *
4
  from langchain.memory import ChatMessageHistory
5
  from modules.chat.base import BaseRAG
6
+ from langchain_core.prompts import PromptTemplate
7
+ from langchain.memory import (
8
+ ConversationBufferWindowMemory,
9
+ ConversationSummaryBufferMemory,
10
+ )
11
 
12
 
13
+ class Langchain_RAG_V1(BaseRAG):
14
+
15
+ def __init__(
16
+ self, llm, memory, retriever, qa_prompt: str, rephrase_prompt: str, config: dict
17
+ ):
18
+ """
19
+ Initialize the Langchain_RAG class.
20
+
21
+ Args:
22
+ llm (LanguageModelLike): The language model instance.
23
+ memory (BaseChatMessageHistory): The chat message history instance.
24
+ retriever (BaseRetriever): The retriever instance.
25
+ qa_prompt (str): The QA prompt string.
26
+ rephrase_prompt (str): The rephrase prompt string.
27
+ """
28
+ self.llm = llm
29
+ self.config = config
30
+ # self.memory = self.add_history_from_list(memory)
31
+ self.memory = ConversationBufferWindowMemory(
32
+ k=self.config["llm_params"]["memory_window"],
33
+ memory_key="chat_history",
34
+ return_messages=True,
35
+ output_key="answer",
36
+ max_token_limit=128,
37
+ )
38
+ self.retriever = retriever
39
+ self.qa_prompt = qa_prompt
40
+ self.rephrase_prompt = rephrase_prompt
41
+ self.store = {}
42
+
43
+ self.qa_prompt = PromptTemplate(
44
+ template=self.qa_prompt,
45
+ input_variables=["context", "chat_history", "input"],
46
+ )
47
+
48
+ self.rag_chain = CustomConversationalRetrievalChain.from_llm(
49
+ llm=llm,
50
+ chain_type="stuff",
51
+ retriever=retriever,
52
+ return_source_documents=True,
53
+ memory=self.memory,
54
+ combine_docs_chain_kwargs={"prompt": self.qa_prompt},
55
+ response_if_no_docs_found="No context found",
56
+ )
57
+
58
+ def add_history_from_list(self, history_list):
59
+ """
60
+ TODO: Add messages from a list to the chat history.
61
+ """
62
+ history = []
63
+
64
+ return history
65
+
66
+ async def invoke(self, user_query, config):
67
+ """
68
+ Invoke the chain.
69
+
70
+ Args:
71
+ kwargs: The input variables.
72
+
73
+ Returns:
74
+ dict: The output variables.
75
+ """
76
+ res = await self.rag_chain.acall(user_query["input"])
77
+ return res
78
+
79
+
80
+ class Langchain_RAG_V2(BaseRAG):
81
+ def __init__(
82
+ self, llm, memory, retriever, qa_prompt: str, rephrase_prompt: str, config: dict
83
+ ):
84
  """
85
  Initialize the Langchain_RAG class.
86
 
 
192
  ) # add previous messages to the store. Note: the store is in-memory.
193
  return self.store[(user_id, conversation_id)]
194
 
195
+ async def invoke(self, user_query, config):
196
  """
197
  Invoke the chain.
198
 
 
202
  Returns:
203
  dict: The output variables.
204
  """
205
+ res = await self.rag_chain.ainvoke(user_query, config)
206
  res["rephrase_prompt"] = self.rephrase_prompt
207
  res["qa_prompt"] = self.qa_prompt
208
  return res
code/modules/chat/langchain/utils.py CHANGED
@@ -35,6 +35,133 @@ from langchain_core.runnables.config import RunnableConfig
35
  from langchain_core.messages import BaseMessage
36
 
37
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
  class CustomRunnableWithHistory(RunnableWithMessageHistory):
39
 
40
  def _get_chat_history(self, chat_history: List[CHAT_TURN_TYPE]) -> str:
@@ -69,6 +196,10 @@ class CustomRunnableWithHistory(RunnableWithMessageHistory):
69
  List[BaseMessage]: The last k conversations.
70
  """
71
  hist: BaseChatMessageHistory = config["configurable"]["message_history"]
 
 
 
 
72
  messages = hist.messages.copy()
73
 
74
  if not self.history_messages_key:
@@ -83,6 +214,9 @@ class CustomRunnableWithHistory(RunnableWithMessageHistory):
83
 
84
  messages = self._get_chat_history(messages)
85
 
 
 
 
86
  return messages
87
 
88
 
@@ -103,22 +237,6 @@ class InMemoryHistory(BaseChatMessageHistory, BaseModel):
103
  """Return the number of messages."""
104
  return len(self.messages)
105
 
106
- def get_last_n_conversations(self, n: int) -> "InMemoryHistory":
107
- """Return a new InMemoryHistory object with the last n conversations from the message history.
108
-
109
- Args:
110
- n (int): The number of last conversations to return. If 0, return an empty history.
111
-
112
- Returns:
113
- InMemoryHistory: A new InMemoryHistory object containing the last n conversations.
114
- """
115
- if n == 0:
116
- return InMemoryHistory()
117
- # Each conversation consists of a pair of messages (human + AI)
118
- num_messages = n * 2
119
- last_messages = self.messages[-num_messages:]
120
- return InMemoryHistory(messages=last_messages)
121
-
122
 
123
  def create_history_aware_retriever(
124
  llm: LanguageModelLike,
 
35
  from langchain_core.messages import BaseMessage
36
 
37
 
38
+ from langchain_core.output_parsers import StrOutputParser
39
+ from langchain_core.prompts import ChatPromptTemplate
40
+ from langchain_community.chat_models import ChatOpenAI
41
+
42
+ from langchain.chains import RetrievalQA, ConversationalRetrievalChain
43
+ from langchain_core.callbacks.manager import AsyncCallbackManagerForChainRun
44
+
45
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
46
+ from langchain_core.callbacks.manager import AsyncCallbackManagerForChainRun
47
+ import inspect
48
+ from langchain.chains.conversational_retrieval.base import _get_chat_history
49
+ from langchain_core.messages import BaseMessage
50
+
51
+
52
+ class CustomConversationalRetrievalChain(ConversationalRetrievalChain):
53
+
54
+ def _get_chat_history(self, chat_history: List[CHAT_TURN_TYPE]) -> str:
55
+ _ROLE_MAP = {"human": "Student: ", "ai": "AI Tutor: "}
56
+ buffer = ""
57
+ for dialogue_turn in chat_history:
58
+ if isinstance(dialogue_turn, BaseMessage):
59
+ role_prefix = _ROLE_MAP.get(
60
+ dialogue_turn.type, f"{dialogue_turn.type}: "
61
+ )
62
+ buffer += f"\n{role_prefix}{dialogue_turn.content}"
63
+ elif isinstance(dialogue_turn, tuple):
64
+ human = "Student: " + dialogue_turn[0]
65
+ ai = "AI Tutor: " + dialogue_turn[1]
66
+ buffer += "\n" + "\n".join([human, ai])
67
+ else:
68
+ raise ValueError(
69
+ f"Unsupported chat history format: {type(dialogue_turn)}."
70
+ f" Full chat history: {chat_history} "
71
+ )
72
+ return buffer
73
+
74
+ async def _acall(
75
+ self,
76
+ inputs: Dict[str, Any],
77
+ run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
78
+ ) -> Dict[str, Any]:
79
+ _run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
80
+ question = inputs["question"]
81
+ get_chat_history = self._get_chat_history
82
+ chat_history_str = get_chat_history(inputs["chat_history"])
83
+ if chat_history_str:
84
+ # callbacks = _run_manager.get_child()
85
+ # new_question = await self.question_generator.arun(
86
+ # question=question, chat_history=chat_history_str, callbacks=callbacks
87
+ # )
88
+ system = (
89
+ "You are someone that rephrases statements. Rephrase the student's question to add context from their chat history if relevant, ensuring it remains from the student's point of view. "
90
+ "Incorporate relevant details from the chat history to make the question clearer and more specific."
91
+ "Do not change the meaning of the original statement, and maintain the student's tone and perspective. "
92
+ "If the question is conversational and doesn't require context, do not rephrase it. "
93
+ "Example: If the student previously asked about backpropagation in the context of deep learning and now asks 'what is it', rephrase to 'What is backprogatation.'. "
94
+ "Example: Do not rephrase if the user is asking something specific like 'cool, suggest a project with transformers to use as my final project'"
95
+ "Chat history: \n{chat_history_str}\n"
96
+ "Rephrase the following question only if necessary: '{input}'"
97
+ )
98
+
99
+ prompt = ChatPromptTemplate.from_messages(
100
+ [
101
+ ("system", system),
102
+ ("human", "{input}, {chat_history_str}"),
103
+ ]
104
+ )
105
+ llm = ChatOpenAI(model="gpt-3.5-turbo-0125", temperature=0)
106
+ step_back = prompt | llm | StrOutputParser()
107
+ new_question = step_back.invoke(
108
+ {"input": question, "chat_history_str": chat_history_str}
109
+ )
110
+ else:
111
+ new_question = question
112
+ accepts_run_manager = (
113
+ "run_manager" in inspect.signature(self._aget_docs).parameters
114
+ )
115
+ if accepts_run_manager:
116
+ docs = await self._aget_docs(new_question, inputs, run_manager=_run_manager)
117
+ else:
118
+ docs = await self._aget_docs(new_question, inputs) # type: ignore[call-arg]
119
+
120
+ output: Dict[str, Any] = {}
121
+ output["original_question"] = question
122
+ if self.response_if_no_docs_found is not None and len(docs) == 0:
123
+ output[self.output_key] = self.response_if_no_docs_found
124
+ else:
125
+ new_inputs = inputs.copy()
126
+ if self.rephrase_question:
127
+ new_inputs["question"] = new_question
128
+ new_inputs["chat_history"] = chat_history_str
129
+
130
+ # Prepare the final prompt with metadata
131
+ context = "\n\n".join(
132
+ [
133
+ f"Context {idx+1}: \n(Document content: {doc.page_content}\nMetadata: (source_file: {doc.metadata['source'] if 'source' in doc.metadata else 'unknown'}))"
134
+ for idx, doc in enumerate(docs)
135
+ ]
136
+ )
137
+ final_prompt = (
138
+ "You are an AI Tutor for the course DS598, taught by Prof. Thomas Gardos. Answer the user's question using the provided context. Only use the context if it is relevant. The context is ordered by relevance."
139
+ "If you don't know the answer, do your best without making things up. Keep the conversation flowing naturally. "
140
+ "Use chat history and context as guides but avoid repeating past responses. Provide links from the source_file metadata. Use the source context that is most relevent."
141
+ "Speak in a friendly and engaging manner, like talking to a friend. Avoid sounding repetitive or robotic.\n\n"
142
+ f"Chat History:\n{chat_history_str}\n\n"
143
+ f"Context:\n{context}\n\n"
144
+ "Answer the student's question below in a friendly, concise, and engaging manner. Use the context and history only if relevant, otherwise, engage in a free-flowing conversation.\n"
145
+ f"Student: {input}\n"
146
+ "AI Tutor:"
147
+ )
148
+
149
+ new_inputs["input"] = final_prompt
150
+ # new_inputs["question"] = final_prompt
151
+ # output["final_prompt"] = final_prompt
152
+
153
+ answer = await self.combine_docs_chain.arun(
154
+ input_documents=docs, callbacks=_run_manager.get_child(), **new_inputs
155
+ )
156
+ output[self.output_key] = answer
157
+
158
+ if self.return_source_documents:
159
+ output["source_documents"] = docs
160
+ output["rephrased_question"] = new_question
161
+ output["context"] = output["source_documents"]
162
+ return output
163
+
164
+
165
  class CustomRunnableWithHistory(RunnableWithMessageHistory):
166
 
167
  def _get_chat_history(self, chat_history: List[CHAT_TURN_TYPE]) -> str:
 
196
  List[BaseMessage]: The last k conversations.
197
  """
198
  hist: BaseChatMessageHistory = config["configurable"]["message_history"]
199
+
200
+ print("\n\n\n")
201
+ print("Hist: ", hist)
202
+ print("\n\n\n")
203
  messages = hist.messages.copy()
204
 
205
  if not self.history_messages_key:
 
214
 
215
  messages = self._get_chat_history(messages)
216
 
217
+ print("\n\n\n")
218
+ print("Messages: ", messages)
219
+ print("\n\n\n")
220
  return messages
221
 
222
 
 
237
  """Return the number of messages."""
238
  return len(self.messages)
239
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
240
 
241
  def create_history_aware_retriever(
242
  llm: LanguageModelLike,
code/modules/chat/llm_tutor.py CHANGED
@@ -2,7 +2,7 @@ from modules.chat.helpers import get_prompt
2
  from modules.chat.chat_model_loader import ChatModelLoader
3
  from modules.vectorstore.store_manager import VectorStoreManager
4
  from modules.retriever.retriever import Retriever
5
- from modules.chat.langchain.langchain_rag import Langchain_RAG
6
 
7
 
8
  class LLMTutor:
@@ -103,12 +103,13 @@ class LLMTutor:
103
  retriever = Retriever(self.config)._return_retriever(db)
104
 
105
  if self.config["llm_params"]["llm_arch"] == "langchain":
106
- self.qa_chain = Langchain_RAG(
107
  llm=llm,
108
  memory=memory,
109
  retriever=retriever,
110
  qa_prompt=qa_prompt,
111
  rephrase_prompt=rephrase_prompt,
 
112
  )
113
  else:
114
  raise ValueError(
 
2
  from modules.chat.chat_model_loader import ChatModelLoader
3
  from modules.vectorstore.store_manager import VectorStoreManager
4
  from modules.retriever.retriever import Retriever
5
+ from modules.chat.langchain.langchain_rag import Langchain_RAG_V1, Langchain_RAG_V2
6
 
7
 
8
  class LLMTutor:
 
103
  retriever = Retriever(self.config)._return_retriever(db)
104
 
105
  if self.config["llm_params"]["llm_arch"] == "langchain":
106
+ self.qa_chain = Langchain_RAG_V2(
107
  llm=llm,
108
  memory=memory,
109
  retriever=retriever,
110
  qa_prompt=qa_prompt,
111
  rephrase_prompt=rephrase_prompt,
112
+ config=self.config,
113
  )
114
  else:
115
  raise ValueError(