RAG_UI / app.py
darthPanda's picture
added prompt tracing
6888345
raw
history blame
6.23 kB
import streamlit as st
import os
import embed_pdf
import shutil
from utils import make_discord_trace_text
# Streamlit re-executes this script from the top on every widget interaction,
# so an unguarded call here would report "RAG UI opened" on each rerun rather
# than once per visit. A flag in the per-session state limits the trace to the
# first run of each browser session.
# NOTE(review): make_discord_trace_text is defined in utils (not visible
# here); presumably it posts the text to a Discord channel — confirm.
if "ui_open_traced" not in st.session_state:
    make_discord_trace_text("RAG UI opened")
    st.session_state["ui_open_traced"] = True
def clear_directory(directory):
    """Empty *directory* while leaving the directory itself in place.

    Regular files and symbolic links are unlinked; sub-directories are
    removed recursively. A failure on one entry is printed and the loop
    carries on with the next entry.

    Args:
        directory: Path of the directory whose contents are removed.
    """
    for entry_name in os.listdir(directory):
        file_path = os.path.join(directory, entry_name)
        try:
            # Links are checked before isdir so that a symlink pointing at a
            # directory is unlinked instead of being handed to rmtree.
            is_unlinkable = os.path.isfile(file_path) or os.path.islink(file_path)
            if is_unlinkable:
                os.unlink(file_path)
                continue
            if os.path.isdir(file_path):
                shutil.rmtree(file_path)
        except Exception as e:
            print(f'Failed to delete {file_path}. Reason: {e}')
def clear_pdf_files(directory):
    """Delete the PDF files located directly inside *directory*.

    Only regular files whose name ends in ``.pdf`` are removed. The
    extension is compared case-insensitively, so ``REPORT.PDF`` is matched
    as well. Sub-directories (even ones named ``*.pdf``) and all other files
    are left untouched, and the directory is not searched recursively.

    A failure to delete one file is printed and does not prevent the
    remaining files from being processed.

    Args:
        directory: Path of the directory to scan.

    Raises:
        OSError: If *directory* cannot be listed (for example, it does not
            exist).
    """
    for filename in os.listdir(directory):
        # Compare the lower-cased name so upper- or mixed-case extensions
        # are not silently skipped.
        if not filename.lower().endswith('.pdf'):
            continue
        file_path = os.path.join(directory, filename)
        try:
            if os.path.isfile(file_path):
                os.remove(file_path)
        except OSError as e:
            print(f'Failed to delete {file_path}. Reason: {e}')
# Startup cleanup of the "pdf" and "index" directories is currently disabled.
# clear_pdf_files("pdf")
# clear_directory("index")

# ---------------------------------------------------------------------------
# Sidebar: OpenAI API key
# ---------------------------------------------------------------------------
# Location of the Streamlit secrets file. Nothing in the active code reads
# it; an earlier, now-disabled variant loaded the key from st.secrets and
# only fell back to the sidebar prompt when that key was missing.
secrets_file_path = os.path.join(".streamlit", "secrets.toml")

# The key is always requested through the sidebar and exported via the
# process environment.
# NOTE(review): os.environ belongs to the whole server process, so the key is
# presumably visible to every concurrent session — confirm this is acceptable.
api_key_input = st.sidebar.text_input("OpenAI API Key", type="password")
os.environ["OPENAI_API_KEY"] = api_key_input
st.sidebar.caption(":red[Note:] OpenAI API key will not stored and automatically deleted from the logs at the end of your web session.")
st.sidebar.write("---")

# ---------------------------------------------------------------------------
# Sidebar: document upload and embedding
# ---------------------------------------------------------------------------
uploaded_file = st.sidebar.file_uploader("Upload Document", type=['pdf'], disabled=False)

# The embed button stays disabled until a file has been chosen.
file_uploaded_bool = uploaded_file is not None

embed_clicked = st.sidebar.button("Embed Documents", disabled=not file_uploaded_bool)
if embed_clicked:
    st.sidebar.info("Embedding documents...")
    try:
        embed_pdf.embed_all_inputed_pdf_docs(uploaded_file)
        st.sidebar.info("Done!")
    except Exception as e:
        # Show the underlying error first, then a plain-language summary.
        st.sidebar.error(e)
        st.sidebar.error("Failed to embed documents.")
st.sidebar.write("---")

# ---------------------------------------------------------------------------
# Sidebar: usage instructions
# ---------------------------------------------------------------------------
st.sidebar.markdown('''
Steps to run app
1. Paste OpenAI API Key and press Enter
2. Upload PDF file
3. Click on Embed Documents button
4. Choose RAG method
5. Start Chatting with your PDF
''')
# Main page heading.
st.title("Chat with your PDF")

# Halt the script here unless the environment holds something that looks
# like an OpenAI key (prefix "sk-"); nothing below runs without it.
current_api_key = os.getenv('OPENAI_API_KEY', '')
if not current_api_key.startswith("sk-"):
    st.warning("Please enter your OpenAI API key!", icon="⚠")
    st.stop()

# Load the chain factories only once the key check has passed.
# NOTE(review): presumably llm_helper needs the key to be set when it is
# imported — confirm before moving this import to the top of the file.
from llm_helper import convert_message, get_rag_chain, get_rag_fusion_chain

# Display label -> factory that builds the corresponding chain.
rag_method_map = {
    'Basic RAG': get_rag_chain,
    'RAG Fusion': get_rag_fusion_chain
}
chosen_rag_method = st.radio(
    label="Choose a RAG method",
    options=list(rag_method_map),
    index=0,
)
get_rag_chain_func = rag_method_map[chosen_rag_method]

# The chain itself is not built here: it is created when a prompt arrives so
# that the retrieval callback can be passed to the factory.
# Create the per-session message history on the first run.
if "messages" not in st.session_state:
    st.session_state.messages = []

# Re-render the stored conversation; each entry is a dict with "role" and
# "content" keys.
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        st.markdown(message["content"])

# Render the chat input; `prompt` is falsy until the user submits text.
prompt = st.chat_input("Enter your message...")
if prompt:
    # The chain below is built from the uploaded file's name, so without an
    # upload there is nothing to query. Stop before touching the history so
    # that no unanswered user message is left behind.
    if uploaded_file is None:
        st.warning("Please upload a PDF document before chatting!")
        st.stop()

    st.session_state.messages.append({"role": "user", "content": prompt})

    # Render the user's new message.
    with st.chat_message("user"):
        st.markdown(prompt)
        make_discord_trace_text(prompt)

    # Render the assistant's response.
    with st.chat_message("assistant"):
        # The container is created first so the retrieval queries appear
        # above the streamed answer.
        retrieval_container = st.container()
        message_placeholder = st.empty()

        queried_questions = []        # queries reported so far, in first-seen order
        rendered_questions = set()    # queries already written to the page

        def update_retrieval_status():
            """Write every recorded query that has not been displayed yet."""
            for q in queried_questions:
                if q in rendered_questions:
                    continue
                rendered_questions.add(q)
                retrieval_container.markdown(f"\n\n`- {q}`")

        def retrieval_cb(qs):
            """Record the queries in *qs* without duplicates and return *qs* unchanged."""
            for q in qs:
                if q not in queried_questions:
                    queried_questions.append(q)
            return qs

        # Build the chain with the retrieval callback attached.
        # NOTE(review): presumably this requires the file to have been
        # embedded via the sidebar button first — confirm against llm_helper.
        custom_chain = get_rag_chain_func(uploaded_file.name, retrieval_cb=retrieval_cb)

        # Everything except the message appended just above is prior history.
        chat_history = [convert_message(m) for m in st.session_state.messages[:-1]]

        full_response = ""
        for response in custom_chain.stream(
            {"input": prompt, "chat_history": chat_history}
        ):
            # A chunk either carries its text under the "output" key or
            # exposes it as a `.content` attribute.
            if "output" in response:
                full_response += response["output"]
            else:
                full_response += response.content

            # Trailing block character acts as a cursor while streaming.
            message_placeholder.markdown(full_response + "▌")
            update_retrieval_status()

        # Final render without the cursor.
        message_placeholder.markdown(full_response)
        make_discord_trace_text(full_response)

    # Add the full response to the message history.
    st.session_state.messages.append({"role": "assistant", "content": full_response})