import streamlit as st
import os
from pathlib import Path
from llama_index.core import (
    Settings,
    SimpleDirectoryReader,
    StorageContext,
    VectorStoreIndex,
    load_index_from_storage,
)
from llama_index.core.tools import QueryEngineTool
from llama_index.core.objects import ObjectIndex
from llama_index.core.agent import ReActAgent
from llama_index.llms.groq import Groq
from llama_index.embeddings.huggingface import HuggingFaceEmbedding

# Build a vector query tool for a single document, reusing its persisted index when available
def create_doc_tools(document_fp: str, doc_name: str, verbose: bool = True) -> QueryEngineTool:
    Settings.llm = Groq(model="mixtral-8x7b-32768")
    Settings.embed_model = HuggingFaceEmbedding(model_name="BAAI/bge-large-en-v1.5")

    persist_dir = f"/home/user/app/agentic_index/{doc_name}"
    if os.path.isdir(persist_dir):
        # Reload the pre-built index from disk
        storage_context = StorageContext.from_defaults(persist_dir=persist_dir)
        vector_index = load_index_from_storage(storage_context)
    else:
        # First run: embed the document, then persist the index for reuse
        documents = SimpleDirectoryReader(input_files=[document_fp]).load_data()
        vector_index = VectorStoreIndex.from_documents(documents)
        vector_index.storage_context.persist(persist_dir=persist_dir)

    vector_query_engine = vector_index.as_query_engine()

    vector_tool = QueryEngineTool.from_defaults(
        name=f"{doc_name}_vector_query_engine_tool",
        query_engine=vector_query_engine,
        description=f"Useful for retrieving specific context from {doc_name}.",
    )

    return vector_tool
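
# Example call (hypothetical document name; any file under the corpus directory works):
#   tool = create_doc_tools(
#       document_fp="/home/user/app/rag_docs_final_review_tex_merged/sample_paper.tex",
#       doc_name="sample_paper",
#   )
#   print(tool.query_engine.query("What problem does this paper address?"))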

# Recursively find and sort .tex/.txt source files under a directory
def find_tex_files(directory: str):
    tex_files = []
    for root, _, files in os.walk(directory):
        for file in files:
            if file.endswith(('.tex', '.txt')):
                tex_files.append(os.path.abspath(os.path.join(root, file)))
    tex_files.sort()
    return tex_files
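
# For a hypothetical corpus layout containing paper_a.tex and notes/paper_b.txt,
# find_tex_files returns the sorted absolute paths:
#   ['/home/user/app/rag_docs_final_review_tex_merged/notes/paper_b.txt',
#    '/home/user/app/rag_docs_final_review_tex_merged/paper_a.tex']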

# Main app function
def main():
    st.title("Scientific Document Question Answering with LlamaIndex and Groq")

    # API Key input
    api_key = st.text_input("Enter your Groq API Key", type="password")

    if api_key:
        # Expose the key to the Groq client, which reads GROQ_API_KEY by default
        os.environ["GROQ_API_KEY"] = api_key

        directory = '/home/user/app/rag_docs_final_review_tex_merged'
        tex_files = find_tex_files(directory)

        paper_to_tools_dict = {}
        for paper in tex_files:
            path = Path(paper)
            vector_tool = create_doc_tools(doc_name=path.stem, document_fp=str(path))
            paper_to_tools_dict[path.stem] = [vector_tool]

        initial_tools = [t for paper in tex_files for t in paper_to_tools_dict[Path(paper).stem]]

        # Index the tools themselves so the agent retrieves only the most
        # relevant document tools for each question, rather than all of them
        obj_index = ObjectIndex.from_objects(
            initial_tools,
            index_cls=VectorStoreIndex,
        )

        # Surface at most six candidate tools per query
        obj_retriever = obj_index.as_retriever(similarity_top_k=6)

        llm = Groq(model="mixtral-8x7b-32768", api_key=api_key)

        context = """You are an agent designed to answer scientific queries over a set of given documents.
        Please always use the tools provided to answer a question. Do not rely on prior knowledge.
        """

        # The ReAct agent interleaves reasoning steps with tool calls, using the
        # retriever to pick which document tools to invoke for each query
        agent = ReActAgent.from_tools(
            tool_retriever=obj_retriever,
            llm=llm,
            verbose=True,
            context=context,
        )

        user_prompt = st.text_input("Enter your question")

        if user_prompt:
            with st.spinner("Processing..."):
                response = agent.query(user_prompt)
                # Keep the markdown flush-left: indented lines inside the
                # f-string would render as a code block
                st.markdown(f"### Query Response\n\n{response}")

if __name__ == "__main__":
    main()
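
# To run locally (assuming this file is saved as app.py):
#   streamlit run app.py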