Carlosito16 commited on
Commit
b671335
1 Parent(s): b68aade

Return question_generator from the function so the chat history can be checked

Browse files
Files changed (1) hide show
  1. app.py +3 -3
app.py CHANGED
@@ -104,7 +104,7 @@ def load_llm_model():
104
 
105
  @st.cache_resource
106
  def load_conversational_qa_memory_retriever():
107
- global question_generator
108
  question_generator = LLMChain(llm=llm_model, prompt=CONDENSE_QUESTION_PROMPT)
109
  doc_chain = load_qa_chain(llm_model, chain_type="stuff", prompt = PROMPT)
110
  memory = ConversationBufferWindowMemory(k = 3, memory_key="chat_history", return_messages=True, output_key='answer')
@@ -118,7 +118,7 @@ def load_conversational_qa_memory_retriever():
118
  return_source_documents=True,
119
  memory = memory,
120
  get_chat_history=lambda h :h)
121
- return conversational_qa_memory_retriever
122
 
123
  def load_retriever(llm, db):
124
  qa_retriever = RetrievalQA.from_chain_type(llm=llm, chain_type="stuff",
@@ -232,7 +232,7 @@ embedding_model = load_embedding_model()
232
  vector_database = load_faiss_index()
233
  llm_model = load_llm_model()
234
  qa_retriever = load_retriever(llm= llm_model, db= vector_database)
235
- conversational_qa_memory_retriever = load_conversational_qa_memory_retriever()
236
  print("all load done")
237
 
238
  #Addional things for Conversation flows
 
104
 
105
  @st.cache_resource
106
  def load_conversational_qa_memory_retriever():
107
+
108
  question_generator = LLMChain(llm=llm_model, prompt=CONDENSE_QUESTION_PROMPT)
109
  doc_chain = load_qa_chain(llm_model, chain_type="stuff", prompt = PROMPT)
110
  memory = ConversationBufferWindowMemory(k = 3, memory_key="chat_history", return_messages=True, output_key='answer')
 
118
  return_source_documents=True,
119
  memory = memory,
120
  get_chat_history=lambda h :h)
121
+ return conversational_qa_memory_retriever, question_generator
122
 
123
  def load_retriever(llm, db):
124
  qa_retriever = RetrievalQA.from_chain_type(llm=llm, chain_type="stuff",
 
232
  vector_database = load_faiss_index()
233
  llm_model = load_llm_model()
234
  qa_retriever = load_retriever(llm= llm_model, db= vector_database)
235
+ conversational_qa_memory_retriever, question_generator = load_conversational_qa_memory_retriever()
236
  print("all load done")
237
 
238
  #Addional things for Conversation flows