AMGPT3 / app.py
achuthc1298's picture
Update app.py
b344204 verified
raw
history blame
No virus
5.83 kB
import streamlit as st
import os
from pathlib import Path
from llama_index.core.selectors import LLMSingleSelector
from llama_index.core.tools import QueryEngineTool
from llama_index.core import VectorStoreIndex
from llama_index.core import Settings
from llama_index.core import SimpleDirectoryReader
from llama_index.llms.groq import Groq
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from typing import Tuple
from llama_index.core import StorageContext, load_index_from_storage
from llama_index.core.objects import ObjectIndex
from llama_index.core.agent import ReActAgent
import time
import sys
import io
# Function to load a document's persisted vector index and wrap it as a query tool
@st.cache_resource
def _get_embed_model(model_name: str):
    """Return a ``HuggingFaceEmbedding`` for *model_name*, loaded once and reused.

    ``st.cache_resource`` keeps the loaded model across calls (and Streamlit
    reruns/sessions), so the weights are not re-loaded for every document.
    """
    return HuggingFaceEmbedding(model_name=model_name)


def create_doc_tools(
    document_fp: "str | os.PathLike[str]",
    doc_name: str,
    verbose: bool = True,
    *,
    index_root: str = "/home/user/app/agentic_index",
) -> QueryEngineTool:
    """Load the persisted vector index for *doc_name* and wrap it as a query tool.

    Side effect: overwrites the global ``Settings.llm`` (Groq
    ``mixtral-8x7b-32768``) and ``Settings.embed_model``
    (``BAAI/bge-large-en-v1.5``) on every call.

    Args:
        document_fp: Path of the source document. It is not read here; the
            index is loaded from ``index_root/doc_name``. Kept so existing
            callers keep working.
        doc_name: Name of the persisted index directory. Also used in the
            tool's name and description.
        verbose: Unused; kept for interface compatibility.
        index_root: Directory holding one persisted index per document.

    Returns:
        A ``QueryEngineTool`` named ``"<doc_name>_vector_query_engine_tool"``.

    Raises:
        Presumably an error from LlamaIndex storage loading if no persisted
        index exists under ``index_root/doc_name`` (TODO confirm exact type).
    """
    # These globals are used by LlamaIndex components that are not handed an
    # explicit llm / embed_model, such as the query engine built below.
    Settings.llm = Groq(model="mixtral-8x7b-32768")
    # Cached: constructing HuggingFaceEmbedding loads the model weights, which
    # previously happened once per document.
    Settings.embed_model = _get_embed_model("BAAI/bge-large-en-v1.5")

    storage_context = StorageContext.from_defaults(
        persist_dir=os.path.join(index_root, doc_name)
    )
    vector_index = load_index_from_storage(storage_context)
    return QueryEngineTool.from_defaults(
        name=f"{doc_name}_vector_query_engine_tool",
        query_engine=vector_index.as_query_engine(),
        description=f"Useful for retrieving specific context from the {doc_name}.",
    )
# Function to recursively find and sort .tex and .txt files
def find_tex_files(
    directory: str, extensions: Tuple[str, ...] = (".tex", ".txt")
) -> "list[str]":
    """Recursively collect files under *directory* that end with one of *extensions*.

    Despite the name, ``.txt`` files are matched by default as well.

    Args:
        directory: Root directory to walk (recursively, via ``os.walk``).
        extensions: Filename suffixes to keep. A single string is also
            accepted, as is any iterable of strings.

    Returns:
        Absolute paths of the matching files, sorted lexicographically. Empty
        if nothing matches or *directory* does not exist (``os.walk`` then
        yields nothing).
    """
    # str.endswith needs a str or a tuple; normalise lists/other iterables.
    if isinstance(extensions, str):
        suffixes = (extensions,)
    else:
        suffixes = tuple(extensions)

    matches = [
        os.path.abspath(os.path.join(root, name))
        for root, _subdirs, files in os.walk(directory)
        for name in files
        if name.endswith(suffixes)
    ]
    matches.sort()
    return matches
# Main app function
def _load_agent(llm, directory="/home/user/app/rag_docs_final_review_tex_merged"):
    """Build a ReActAgent over one vector query tool per document in *directory*."""
    # Keyed by file stem (the persisted index name). If two files share a stem,
    # the later one replaces the earlier, so each stem yields exactly one tool.
    paper_to_tools_dict = {}
    for paper in find_tex_files(directory):
        path = Path(paper)
        paper_to_tools_dict[path.stem] = [
            create_doc_tools(doc_name=path.stem, document_fp=path)
        ]
    initial_tools = [tool for tools in paper_to_tools_dict.values() for tool in tools]

    # Index the tools themselves so the agent is handed only the 6 most
    # similar tools rather than every document's tool.
    obj_index = ObjectIndex.from_objects(initial_tools, index_cls=VectorStoreIndex)
    obj_retriever = obj_index.as_retriever(similarity_top_k=6)
    context = (
        "You are an agent designed to answer scientific queries over a set of given documents.\n"
        "Please always use the tools provided to answer a question. Do not rely on prior knowledge.\n"
    )
    # verbose=True makes the agent print its reasoning trace; _query_agent
    # captures that output for the "Verbose" sidebar toggle.
    return ReActAgent.from_tools(
        tool_retriever=obj_retriever,
        llm=llm,
        verbose=True,
        context=context,
    )


def _query_agent(agent, prompt, verbose):
    """Send *prompt* to *agent*; return the captured trace if *verbose*, else the answer."""
    from contextlib import redirect_stdout

    captured = io.StringIO()
    # redirect_stdout restores the real stdout even if query() raises. A bare
    # swap/restore would leave sys.stdout pointing at the StringIO after an
    # error. NOTE(review): sys.stdout is process-global, so other Streamlit
    # sessions' prints are also captured while a query runs.
    with redirect_stdout(captured):
        response = agent.query(prompt)
    if verbose:
        # Drop the "===" separators and end every chunk with a newline.
        return "\n".join(captured.getvalue().split("===")) + "\n"
    return str(response.response)


def main():
    """Render the AMGPT Streamlit chat UI.

    Once a Groq API key is entered, the agent is built a single time per
    session and kept in ``st.session_state["agent"]``. The chat history lives
    in ``st.session_state["messages"]``.
    """
    st.title("AMGPT, Powered by LlamaIndex")

    # API key input
    apikey = st.text_input("Enter your Groq API Key", type="password")
    os.environ["GROQ_API_KEY"] = apikey
    llm = Groq(model="mixtral-8x7b-32768")

    with st.sidebar:
        verbose_toggle = st.toggle("Verbose")  # show the agent trace instead of only the answer
        reset = st.button('Reset Chat!')  # reset the chat

    # Build the agent once per session. On failure, tools_loaded stays unset,
    # so loading is retried on the next rerun.
    if apikey and "tools_loaded" not in st.session_state:
        try:
            with st.spinner('Please wait, AMGPT is loading....'):
                agent = _load_agent(llm)
            st.success('Done!, you may start asking questions now')
            st.session_state["tools_loaded"] = True
            st.session_state["agent"] = agent
        except Exception as e:
            st.error(e)

    if "messages" not in st.session_state or reset:
        st.session_state["messages"] = [{"role": "assistant", "content": "How can I help you?"}]

    for msg in st.session_state.messages:
        st.chat_message(msg["role"]).write(msg["content"])

    if prompt := st.chat_input():
        # The user started chatting without providing a Groq API key.
        if not apikey:
            st.info("Please add your Groq API key to continue.")
            st.stop()

        st.session_state.messages.append({"role": "user", "content": prompt})
        st.chat_message("user").write(prompt)
        try:
            with st.spinner('Wait for output...'):
                msg = _query_agent(st.session_state.agent, prompt, verbose_toggle)
            st.session_state.messages.append({"role": "assistant", "content": msg})
            st.chat_message("assistant").markdown(msg)
        except Exception as e:
            st.error(e)
if __name__ == "__main__":
    # Script entry point: render the app when this file is run as the main module.
    main()