# Spaces: Sleeping  (hosting-page status banner captured with this file; not program code)
import contextlib
import functools
import io
import os
import sys
import time
from pathlib import Path
from typing import Tuple

import streamlit as st
from llama_index.core import Settings
from llama_index.core import SimpleDirectoryReader
from llama_index.core import StorageContext, load_index_from_storage
from llama_index.core import VectorStoreIndex
from llama_index.core.agent import ReActAgent
from llama_index.core.objects import ObjectIndex
from llama_index.core.selectors import LLMSingleSelector
from llama_index.core.tools import QueryEngineTool
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.llms.groq import Groq
@functools.lru_cache(maxsize=1)
def _load_embed_model() -> "HuggingFaceEmbedding":
    """Return the shared embedding model, constructing it on first use only.

    ``create_doc_tools`` is called once per document; caching here keeps the
    model from being re-instantiated on every one of those calls.
    """
    return HuggingFaceEmbedding(model_name="BAAI/bge-large-en-v1.5")


# Function to create the query tool for one document from its persisted index
def create_doc_tools(
    document_fp: str,
    doc_name: str,
    verbose: bool = True,
    index_root: str = "/home/user/app/agentic_index",
) -> "QueryEngineTool":
    """Build a vector query-engine tool for a single document.

    The vector index is not built from the document; it is loaded from the
    directory ``{index_root}/{doc_name}``, which must already contain a
    persisted index.

    Args:
        document_fp: Path of the source document. Kept for interface
            compatibility; the file itself is not read, because the index is
            loaded from disk rather than rebuilt.
        doc_name: Document identifier. Selects the persisted index directory
            and is embedded in the tool's name and description.
        verbose: Currently unused; kept so existing callers keep working.
        index_root: Directory that holds one persisted index sub-directory per
            document.

    Returns:
        A ``QueryEngineTool`` named ``{doc_name}_vector_query_engine_tool``.

    Raises:
        Whatever the storage/index loaders raise (for example when the
        persisted index directory is missing) propagates to the caller.
    """
    # Global LlamaIndex settings used when the index is loaded and queried.
    # The LLM client is re-created on each call so it is built after the caller
    # has had a chance to set GROQ_API_KEY; the embedding model is shared.
    Settings.llm = Groq(model="mixtral-8x7b-32768")
    Settings.embed_model = _load_embed_model()

    load_dir_path = f"{index_root}/{doc_name}"
    storage_context = StorageContext.from_defaults(persist_dir=load_dir_path)
    vector_index = load_index_from_storage(storage_context)
    vector_query_engine = vector_index.as_query_engine()

    return QueryEngineTool.from_defaults(
        name=f"{doc_name}_vector_query_engine_tool",
        query_engine=vector_query_engine,
        description=f"Useful for retrieving specific context from the {doc_name}.",
    )
# Function to collect .tex/.txt files below a directory, sorted by path
def find_tex_files(directory: str):
    """Return the sorted absolute paths of all ``.tex``/``.txt`` files under *directory*.

    The tree is walked recursively with ``os.walk``. Files are matched by
    name suffix only, and the result is a list of strings in ascending order.
    """
    wanted_suffixes = ('.tex', '.txt')
    return sorted(
        os.path.abspath(os.path.join(folder, name))
        for folder, _subdirs, names in os.walk(directory)
        for name in names
        if name.endswith(wanted_suffixes)
    )
def _build_agent() -> "ReActAgent":
    """Create one query tool per source document and wrap them in a ReAct agent.

    Must be called after ``GROQ_API_KEY`` is set in the environment, because
    the Groq clients are constructed here and in ``create_doc_tools``.
    """
    directory = '/home/user/app/rag_docs_final_review_tex_merged'
    tex_files = find_tex_files(directory)

    # One tool per document; the file stem doubles as the document name.
    initial_tools = [
        create_doc_tools(doc_name=Path(paper).stem, document_fp=paper)
        for paper in tex_files
    ]

    # Index the tools themselves so only the most relevant ones are retrieved
    # for each question.
    obj_index = ObjectIndex.from_objects(
        initial_tools,
        index_cls=VectorStoreIndex,
    )
    obj_retriever = obj_index.as_retriever(similarity_top_k=6)

    context = (
        "You are an agent designed to answer scientific queries over a set of given documents.\n"
        "Please always use the tools provided to answer a question. Do not rely on prior knowledge.\n"
    )
    llm = Groq(model="mixtral-8x7b-32768")
    return ReActAgent.from_tools(
        tool_retriever=obj_retriever,
        llm=llm,
        # Must stay True: the "Verbose" view in the UI is built from what the
        # agent prints to stdout while answering.
        verbose=True,
        context=context,
    )


def _query_with_trace(agent, prompt: str):
    """Run ``agent.query(prompt)`` and return ``(response, captured_stdout)``.

    ``redirect_stdout`` restores ``sys.stdout`` even when the query raises, so
    a failed query cannot leave the process printing into a dead buffer.
    NOTE(review): ``sys.stdout`` is process-wide, so output printed by other
    threads during the query is captured too.
    """
    buffer = io.StringIO()
    with contextlib.redirect_stdout(buffer):
        response = agent.query(prompt)
    return response, buffer.getvalue()


# Main app function
def main():
    """Render the AMGPT chat UI and answer user prompts with the document agent.

    Flow on each Streamlit run:
      1. Ask for the Groq API key and draw the sidebar controls.
      2. Once a key is present, build the agent a single time per session and
         keep it in ``st.session_state["agent"]``.
      3. Replay the stored chat history, then handle a new prompt if one was
         submitted, showing either the agent's printed trace ("Verbose") or
         only its final answer.
    """
    st.title("AMGPT, Powered by LlamaIndex")

    # API Key input
    apikey = st.text_input("Enter your Groq API Key", type="password")

    with st.sidebar:
        verbose_toggle = st.toggle("Verbose")  # get verbose or only LLM response
        reset = st.button('Reset Chat!')  # reset the chat

    if apikey:
        # The Groq clients read the key from the environment, so export it
        # before anything constructs one.
        # NOTE(review): os.environ is shared by the whole process — confirm
        # this is acceptable if several users can be connected at once.
        os.environ["GROQ_API_KEY"] = apikey

        if "tools_loaded" not in st.session_state:
            try:
                with st.spinner('Please wait, AMGPT is loading....'):
                    agent = _build_agent()
                st.success('Done!, you may start asking questions now')
                # store session state variables
                st.session_state["tools_loaded"] = True
                st.session_state["agent"] = agent
            except Exception as e:
                # Top-level UI boundary: show the failure instead of crashing.
                # "tools_loaded" stays unset, so the next run retries loading.
                st.error(e)

    if "messages" not in st.session_state or reset:
        st.session_state["messages"] = [{"role": "assistant", "content": "How can I help you?"}]

    for msg in st.session_state.messages:
        st.chat_message(msg["role"]).write(msg["content"])

    if prompt := st.chat_input():
        # if the user started chatting without setting the Groq API key
        if not apikey:
            st.info("Please add your Groq API key to continue.")
            st.stop()
        # A key was given but loading failed above; say so explicitly rather
        # than surfacing an attribute error from session_state.
        if "agent" not in st.session_state:
            st.error("The agent could not be loaded, so the question cannot be answered.")
            st.stop()

        st.session_state.messages.append({"role": "user", "content": prompt})
        st.chat_message("user").write(prompt)

        try:
            with st.spinner('Wait for output...'):
                response, output = _query_with_trace(st.session_state.agent, prompt)

            # format the received verbose output: split on the '===' markers
            # and put each piece on its own line
            verbose = "".join(f"{chunk}\n" for chunk in output.split('==='))

            # assistant response
            msg = verbose if verbose_toggle else str(response.response)

            # write the response
            st.session_state.messages.append({"role": "assistant", "content": msg})
            st.chat_message("assistant").markdown(msg)
        except Exception as e:
            # Top-level UI boundary: report query failures in the page.
            st.error(e)
# Script entry point: start the app only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()