nkcong206 committed
Commit 8b54974
1 Parent(s): 9dad315
Files changed (1):
  1. app.py +35 -30
app.py CHANGED
@@ -249,9 +249,6 @@ if "save_dir" not in st.session_state:
 
 if "uploaded_files" not in st.session_state:
     st.session_state.uploaded_files = set()
-
-if "processing" not in st.session_state:
-    st.session_state.processing = False
 
 @st.dialog("Setup Gemini")
 def vote():
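The `processing` flag initialized here is deleted because later hunks replace the flag-guarded build with a blocking dialog. The surrounding context shows the `st.session_state` idiom the whole file depends on; a minimal sketch of it, using keys that appear elsewhere in app.py:

import streamlit as st

# Streamlit reruns the entire script on every interaction, so each
# session key is guarded and assigned only once per session.
defaults = {"uploaded_files": set(), "rag": None, "chat_history": []}
for key, value in defaults.items():
    if key not in st.session_state:
        st.session_state[key] = value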
@@ -261,7 +258,7 @@ def vote():
         """
     )
     key = st.text_input("Key:", "")
-    if st.button("Save"):
+    if st.button("Save") and key != "":
         st.session_state.gemini_api = key
         st.rerun()
 
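The `Save` button is now gated on a non-empty key, so clicking it with a blank field no longer stores an empty API key and dismisses the setup dialog. A self-contained sketch of the `@st.dialog` pattern this hunk edits (requires a recent Streamlit release; the body is condensed from the commit's `vote()`):

import streamlit as st

if "gemini_api" not in st.session_state:
    st.session_state.gemini_api = None

@st.dialog("Setup Gemini")
def vote():
    key = st.text_input("Key:", "")
    # Persist only a non-empty key; st.rerun() closes the dialog
    # and re-executes the script with the key in session state.
    if st.button("Save") and key != "":
        st.session_state.gemini_api = key
        st.rerun()

if st.session_state.gemini_api is None:
    vote()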
@@ -269,7 +266,6 @@ if st.session_state.gemini_api is None:
     vote()
 else:
     os.environ["GOOGLE_API_KEY"] = st.session_state.gemini_api
-
     st.session_state.model = get_chat_google_model(st.session_state.gemini_api)
 
 if st.session_state.save_dir is None:
@@ -284,7 +280,6 @@ def load_txt(file_path):
     return doc
 
 
-
 with st.sidebar:
     uploaded_files = st.file_uploader("Chọn file CSV", accept_multiple_files=True, type=["txt"])
     if st.session_state.gemini_api:
@@ -316,7 +311,7 @@ with st.sidebar:
 
 def format_docs(docs):
     return "\n\n".join(doc.page_content for doc in docs)
-
+
 @st.cache_resource
 def compute_rag_chain(_model, _embd, docs_texts):
     results = recursive_embed_cluster_summarize(_model, _embd, docs_texts, level=1, n_levels=3)
@@ -344,15 +339,18 @@ def compute_rag_chain(_model, _embd, docs_texts):
     )
     return rag_chain
 
-if st.session_state.uploaded_files:
+@st.dialog("Setup RAG")
+def load_rag():
+    docs_texts = [d.page_content for d in documents]
+    st.session_state.rag = compute_rag_chain(st.session_state.model, st.session_state.embd, docs_texts)
+    st.rerun()
+
+if st.session_state.uploaded_files and st.session_state.gemini_api:
     if st.session_state.gemini_api is not None:
-        if st.session_state.rag is None:
-            st.session_state.processing = True
-            docs_texts = [d.page_content for d in documents]
-            st.session_state.rag = compute_rag_chain(st.session_state.model, st.session_state.embd, docs_texts)
-            st.session_state.processing = False
-
-if st.session_state.gemini_api is not None:
+        if st.session_state.rag is None:
+            load_rag()
+
+if st.session_state.gemini_api is not None and st.session_state.gemini_api:
     if st.session_state.llm is None:
         mess = ChatPromptTemplate.from_messages(
             [
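This hunk is why the `processing` flag could go: the index build now runs inside a `Setup RAG` modal that blocks the chat until `st.rerun()` dismisses it, and `@st.cache_resource` on `compute_rag_chain` means a rerun with the same texts returns the cached chain rather than rebuilding (the leading underscores on `_model` and `_embd` tell Streamlit not to hash those arguments). A minimal sketch of the dialog-plus-cache pattern, with a hypothetical `build_chain` standing in for the real pipeline:

import time
import streamlit as st

@st.cache_resource
def build_chain(docs_texts):
    # Stand-in for the real embed/cluster/summarize pipeline; the
    # cache returns the same object for identical docs_texts.
    time.sleep(3)
    return {"n_docs": len(docs_texts)}

@st.dialog("Setup RAG")
def load_rag(docs_texts):
    st.session_state.rag = build_chain(docs_texts)
    st.rerun()  # close the modal once the chain exists

if st.session_state.get("rag") is None:
    load_rag(["first document", "second document"])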
@@ -373,23 +371,30 @@ for message in st.session_state.chat_history:
     with st.chat_message(message["role"]):
         st.write(message["content"])
 
-
-if not st.session_state.processing:
-    if st.session_state.gemini_api:
-        if prompt := st.chat_input("Bạn muốn hỏi gì?"):
-            st.session_state.chat_history.append({"role": "user", "content": prompt})
-
-            with st.chat_message("user"):
-                st.write(prompt)
-
-            with st.chat_message("assistant"):
-                if st.session_state.rag is not None:
+prompt = st.chat_input("Bạn muốn hỏi gì?")
+if st.session_state.gemini_api:
+    if prompt:
+        st.session_state.chat_history.append({"role": "user", "content": prompt})
+
+        with st.chat_message("user"):
+            st.write(prompt)
+
+        with st.chat_message("assistant"):
+            if st.session_state.rag is not None:
+                try:
                     respone = st.session_state.rag.invoke(prompt)
                     st.write(respone)
-                else:
+                except:
+                    respone = "Lỗi Gemini, load lại trang và nhập lại key"
+                    st.write(respone)
+            else:
+                try:
                     ans = st.session_state.llm.invoke(prompt)
                     respone = ans.content
                     st.write(respone)
-
-            st.session_state.chat_history.append({"role": "assistant", "content": respone})
-
+                except:
+                    respone = "Lỗi Gemini, load lại trang và nhập lại key"
+                    st.write(respone)
+
+        st.session_state.chat_history.append({"role": "assistant", "content": respone})
+
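Both `invoke` calls are now wrapped so a failed Gemini request shows a message instead of crashing the app. For reference, "Bạn muốn hỏi gì?" is Vietnamese for "What do you want to ask?", and the fallback string means "Gemini error, reload the page and re-enter the key". The condensed sketch below assumes `rag`, `llm`, and `chat_history` are initialized as in the surrounding code; it catches `Exception` instead of the commit's bare `except:`, spells `response` in full (the commit writes `respone`), and the `answer` helper is illustrative:

import streamlit as st

def answer(prompt: str) -> str:
    # Prefer the RAG chain once it is built; otherwise fall back
    # to the plain chat model, and never let an API error escape.
    try:
        if st.session_state.rag is not None:
            return st.session_state.rag.invoke(prompt)
        return st.session_state.llm.invoke(prompt).content
    except Exception:
        return "Lỗi Gemini, load lại trang và nhập lại key"

if prompt := st.chat_input("Bạn muốn hỏi gì?"):
    st.session_state.chat_history.append({"role": "user", "content": prompt})
    with st.chat_message("user"):
        st.write(prompt)
    with st.chat_message("assistant"):
        response = answer(prompt)
        st.write(response)
    st.session_state.chat_history.append({"role": "assistant", "content": response})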