# ADAFSA RAG chatbot — Gradio app for Hugging Face Spaces.
import os
import pickle
import warnings

import gradio as gr
import spaces
from langchain import hub
from langchain.retrievers import ContextualCompressionRetriever, EnsembleRetriever
from langchain.retrievers.document_compressors import EmbeddingsFilter
from langchain_community.vectorstores import FAISS
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.runnables import RunnablePassthrough
from langchain_groq import ChatGroq
from langchain_huggingface.embeddings import HuggingFaceEmbeddings
# SECURITY FIX: the original file hardcoded a live Groq API key in source
# control. Read it from the environment instead — and rotate the leaked key.
GROQ_API_KEY = os.environ.get("GROQ_API_KEY", "")
if not GROQ_API_KEY:
    # Best-effort warning rather than a crash, so the UI still starts for
    # debugging; LLM calls will fail until the variable is provided.
    warnings.warn("GROQ_API_KEY is not set; ChatGroq requests will fail.")
os.environ['GROQ_API_KEY'] = GROQ_API_KEY
# Multilingual embedding model used for both FAISS stores and the similarity
# filter (queries below are Arabic, so a multilingual model is required).
# NOTE(review): device is hardcoded to "cuda" while the @spaces.GPU decorator
# further down is commented out — confirm a GPU is actually present at startup.
embed_model = HuggingFaceEmbeddings(model_name="Alibaba-NLP/gte-multilingual-base", model_kwargs={"trust_remote_code": True, "device": "cuda"})
# Groq-hosted Llama 3.1 8B Instant; temperature 0.0 for deterministic,
# grounded answers, capped at 1024 output tokens.
llm = ChatGroq(
    model="llama-3.1-8b-instant",
    temperature=0.0,
    max_tokens=1024,
    max_retries=2
)
# Load the two FAISS indexes built offline (Excel tables and recursively split
# Word documents) with the same embedding model, then merge them into a single
# searchable store. allow_dangerous_deserialization is required by FAISS
# load_local and is only safe because these index files are produced in-house.
excel_vectorstore = FAISS.load_local(folder_path="./faiss_excel_doc_index", embeddings=embed_model, allow_dangerous_deserialization=True)
word_vectorstore = FAISS.load_local(folder_path="./faiss_recursive_split_word_doc_index", embeddings=embed_model, allow_dangerous_deserialization=True)
excel_vectorstore.merge_from(word_vectorstore)
combined_vectorstore = excel_vectorstore
# SECURITY NOTE: pickle.load executes arbitrary code on load — this file must
# come from a trusted build artifact, never from user upload.
with open('combined_recursive_keyword_retriever.pkl', 'rb') as f:
    combined_keyword_retriever = pickle.load(f)
combined_keyword_retriever.k = 10  # top-10 keyword matches per query
# Semantic side: MMR search, k=10; lambda_mult=0.25 biases toward diverse hits.
semantic_retriever = combined_vectorstore.as_retriever(search_type="mmr", search_kwargs={'k': 10, 'lambda_mult': 0.25})
# Ensemble retriever: equal-weight (0.5/0.5) rank fusion of the keyword
# retriever and the semantic (FAISS/MMR) retriever.
ensemble_retriever = EnsembleRetriever(
    retrievers=[combined_keyword_retriever, semantic_retriever], weights=[0.5, 0.5]
)
# Post-retrieval compression: drop any candidate chunk whose embedding
# similarity to the query falls below 0.6.
embeddings_filter = EmbeddingsFilter(embeddings=embed_model, similarity_threshold=0.6)
compression_retriever = ContextualCompressionRetriever(
    base_compressor=embeddings_filter, base_retriever=ensemble_retriever
)
# Prompt constraining the assistant to answer strictly from the retrieved
# CONTEXT and to say "I don't know" for out-of-context questions.
template = """
User: You are an AI Assistant that follows instructions extremely well.
Please be truthful and give direct answers. Please tell 'I don't know' if user query is not in CONTEXT
Keep in mind, you will lose the job, if you answer out of CONTEXT questions
CONTEXT: {context}
Query: {question}
Remember only return AI answer
Assistant:
"""
prompt = ChatPromptTemplate.from_template(template)
output_parser = StrOutputParser()  # extract plain text from the chat message
def format_docs(docs):
    """Join the page contents of retrieved documents with blank-line separators."""
    contents = (doc.page_content for doc in docs)
    return "\n\n".join(contents)
# LCEL pipeline: retrieve + filter docs (run name "Docs" for tracing) and
# format them as the context, pass the raw question through, fill the prompt,
# call the LLM, and parse the reply down to plain text.
rag_chain = (
    {"context": compression_retriever.with_config(run_name="Docs") | format_docs, "question": RunnablePassthrough()}
    | prompt
    | llm
    | output_parser
)
# zero = torch.Tensor([0]).cuda() | |
# @spaces.GPU | |
def get_response(question, history):
    """Stream the RAG answer for *question*, yielding the accumulated text.

    *history* is supplied by gr.ChatInterface and is intentionally unused.
    """
    pieces = []
    for token in rag_chain.stream(question):
        pieces.append(token)
        # Gradio replaces the displayed message with each yielded value, so
        # yield the full answer-so-far rather than the individual token.
        yield "".join(pieces)
# Example questions (Arabic) surfaced as clickable prompts in the chat UI;
# they cover both document-QA topics (crops, beekeeping) and statistics-style
# queries (feed quantities, farm areas, food-sample figures).
example_questions = [
    "الموسم المناسب لزراعة الذرة العلفية ؟",
    "ما هي الاحتياجات المائية لتربية الحيوانات؟",
    "ما هي خطوات إنتاج الشتلات؟",
    "الموسم المناسب لزراعة الطماطم في الحقل المكشوف بدولة الإمارات؟",
    "شروط اختيار مكان منحل العسل؟",
    "ما هو تقييم مطعم قصر نجد؟",
    "ما كمية أعلاف الجت المستلمة في منطقة الظفرة عام 2022",
    "ما مساحات المزارع المروية بالتنقيط في منطقة الرحبة عام 2020",
    "في إمارة أبوظبي في عام 2022، هل نسبة العينات الغذائية الغير مطابقة من إجمالي العينات أعلى في العينات المحلية أم العينات المستوردة"
]
# Gradio UI: a titled page hosting a ChatInterface that streams get_response;
# launched synchronously as the Space entry point.
with gr.Blocks() as demo:
    gr.Markdown(
        """
# ADAFSA RAG Chatbot Demo
"""
    )
    chatbot = gr.Chatbot(placeholder="<strong>ADAFSA-RAG Chatbot</strong>")
    gr.ChatInterface(
        title="",
        fn=get_response,  # generator fn -> streaming responses in the UI
        chatbot=chatbot,
        examples=example_questions,
    )
demo.launch()