nkcong206 commited on
Commit
ee39452
1 Parent(s): 65657aa
Files changed (1) hide show
  1. app.py +60 -55
app.py CHANGED
@@ -216,12 +216,33 @@ if "rag" not in st.session_state:
216
 
217
  if "llm" not in st.session_state:
218
  st.session_state.llm = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
219
 
220
- if "model" not in st.session_state:
221
- st.session_state.model = None
 
 
 
 
222
 
223
  if "embd" not in st.session_state:
224
- st.session_state.embd = None
225
 
226
  if "save_dir" not in st.session_state:
227
  st.session_state.save_dir = None
@@ -246,25 +267,9 @@ if st.session_state.gemini_api is None:
246
  else:
247
  os.environ["GOOGLE_API_KEY"] = st.session_state.gemini_api
248
 
249
- st.session_state.model = ChatGoogleGenerativeAI(
250
- model="gemini-1.5-flash",
251
- temperature=0,
252
- max_tokens=None,
253
- timeout=None,
254
- max_retries=2,
255
- )
256
 
257
  st.write(f"Key is set to: {st.session_state.gemini_api}")
258
- model_name="bkai-foundation-models/vietnamese-bi-encoder"
259
- model_kwargs = {'device': 'cpu'}
260
- encode_kwargs = {'normalize_embeddings': False}
261
-
262
- st.session_state.embd = HuggingFaceEmbeddings(
263
- model_name=model_name,
264
- model_kwargs=model_kwargs,
265
- encode_kwargs=encode_kwargs
266
- )
267
-
268
  st.write(f"loaded vietnamese-bi-encoder")
269
 
270
  if st.session_state.save_dir is None:
@@ -278,6 +283,8 @@ def load_txt(file_path):
278
  doc = loader_sv.load()
279
  return doc
280
 
 
 
281
  with st.sidebar:
282
  uploaded_files = st.file_uploader("Chọn file CSV", accept_multiple_files=True, type=["txt"])
283
  if st.session_state.gemini_api:
@@ -306,45 +313,43 @@ with st.sidebar:
306
  else:
307
  st.session_state.uploaded_files = set()
308
  st.session_state.rag = None
309
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
310
  if st.session_state.uploaded_files:
311
  if st.session_state.gemini_api is not None:
312
  with st.spinner("Đang xử lý, vui lòng đợi..."):
313
  if st.session_state.rag is None:
314
- docs_texts = [d.page_content for d in documents]
315
-
316
- results = recursive_embed_cluster_summarize(st.session_state.model, st.session_state.embd, docs_texts, level=1, n_levels=3)
317
-
318
- all_texts = docs_texts.copy()
319
-
320
- for level in sorted(results.keys()):
321
- summaries = results[level][1]["summaries"].tolist()
322
- all_texts.extend(summaries)
323
-
324
- vectorstore = Chroma.from_texts(texts=all_texts, embedding=st.session_state.embd)
325
-
326
- retriever = vectorstore.as_retriever()
327
-
328
- def format_docs(docs):
329
- return "\n\n".join(doc.page_content for doc in docs)
330
-
331
- template = """
332
- Bạn là một trợ lí AI hỗ trợ tuyển sinh và sinh viên. \n
333
- Hãy trả lời câu hỏi chính xác, tập trung vào thông tin liên quan đến câu hỏi. \n
334
- Nếu bạn không biết câu trả lời, hãy nói không biết, đừng cố tạo ra câu trả lời.\n
335
- Dưới đây là thông tin liên quan mà bạn cần sử dụng tới:\n
336
- {context}\n
337
- hãy trả lời:\n
338
- {question}
339
- """
340
- prompt = PromptTemplate(template = template, input_variables=["context", "question"])
341
- rag_chain = (
342
- {"context": retriever | format_docs, "question": RunnablePassthrough()}
343
- | prompt
344
- | st.session_state.model
345
- | StrOutputParser()
346
- )
347
- st.session_state.rag = rag_chain
348
 
349
  if st.session_state.gemini_api is not None:
350
  if st.session_state.llm is None:
 
216
 
217
  if "llm" not in st.session_state:
218
  st.session_state.llm = None
219
+
220
+ @st.cache_resource
221
+ def get_chat_google_model(api_key):
222
+ os.environ["GOOGLE_API_KEY"] = api_key
223
+ return ChatGoogleGenerativeAI(
224
+ model="gemini-1.5-flash",
225
+ temperature=0,
226
+ max_tokens=None,
227
+ timeout=None,
228
+ max_retries=2,
229
+ )
230
+
231
+ @st.cache_resource
232
+ def get_embedding_model():
233
+ model_name = "bkai-foundation-models/vietnamese-bi-encoder"
234
+ model_kwargs = {'device': 'cpu'}
235
+ encode_kwargs = {'normalize_embeddings': False}
236
 
237
+ model = HuggingFaceEmbeddings(
238
+ model_name=model_name,
239
+ model_kwargs=model_kwargs,
240
+ encode_kwargs=encode_kwargs
241
+ )
242
+ return model
243
 
244
  if "embd" not in st.session_state:
245
+ st.session_state.embd = get_embedding_model()
246
 
247
  if "save_dir" not in st.session_state:
248
  st.session_state.save_dir = None
 
267
  else:
268
  os.environ["GOOGLE_API_KEY"] = st.session_state.gemini_api
269
 
270
+ st.session_state.model = get_chat_google_model(st.session_state.gemini_api)
 
 
 
 
 
 
271
 
272
  st.write(f"Key is set to: {st.session_state.gemini_api}")
 
 
 
 
 
 
 
 
 
 
273
  st.write(f"loaded vietnamese-bi-encoder")
274
 
275
  if st.session_state.save_dir is None:
 
283
  doc = loader_sv.load()
284
  return doc
285
 
286
+
287
+
288
  with st.sidebar:
289
  uploaded_files = st.file_uploader("Chọn file CSV", accept_multiple_files=True, type=["txt"])
290
  if st.session_state.gemini_api:
 
313
  else:
314
  st.session_state.uploaded_files = set()
315
  st.session_state.rag = None
316
+
317
+ def format_docs(docs):
318
+ return "\n\n".join(doc.page_content for doc in docs)
319
+
320
+ @st.cache_resource
321
+ def compute_rag_chain(model, embd, docs_texts):
322
+ results = recursive_embed_cluster_summarize(model, embd, docs_texts, level=1, n_levels=3)
323
+ all_texts = docs_texts.copy()
324
+ for level in sorted(results.keys()):
325
+ summaries = results[level][1]["summaries"].tolist()
326
+ all_texts.extend(summaries)
327
+ vectorstore = Chroma.from_texts(texts=all_texts, embedding=embd)
328
+ retriever = vectorstore.as_retriever()
329
+ template = """
330
+ Bạn là một trợ lí AI hỗ trợ tuyển sinh và sinh viên. \n
331
+ Hãy trả lời câu hỏi chính xác, tập trung vào thông tin liên quan đến câu hỏi. \n
332
+ Nếu bạn không biết câu trả lời, hãy nói không biết, đừng cố tạo ra câu trả lời.\n
333
+ Dưới đây là thông tin liên quan mà bạn cần sử dụng tới:\n
334
+ {context}\n
335
+ hãy trả lời:\n
336
+ {question}
337
+ """
338
+ prompt = PromptTemplate(template=template, input_variables=["context", "question"])
339
+ rag_chain = (
340
+ {"context": retriever | format_docs, "question": RunnablePassthrough()}
341
+ | prompt
342
+ | model
343
+ | StrOutputParser()
344
+ )
345
+ return rag_chain
346
+
347
  if st.session_state.uploaded_files:
348
  if st.session_state.gemini_api is not None:
349
  with st.spinner("Đang xử lý, vui lòng đợi..."):
350
  if st.session_state.rag is None:
351
+ docs_texts = [d.page_content for d in documents]
352
+ st.session_state.rag = compute_rag_chain(st.session_state.model, st.session_state.embd, docs_texts)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
353
 
354
  if st.session_state.gemini_api is not None:
355
  if st.session_state.llm is None: