import streamlit as st
import os
import time
import pickle
from datetime import date

from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_huggingface import HuggingFaceEmbeddings
from langchain.retrievers import ParentDocumentRetriever
from langchain.storage import InMemoryStore
from langchain_chroma import Chroma
from langchain_openai import ChatOpenAI
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnableLambda
llm_list = ['Mistral-7B-Instruct-v0.2','Mixtral-8x7B-Instruct-v0.1']
blablador_base = "https://helmholtz-blablador.fz-juelich.de:8000/v1"
# Paths to the persisted Chroma vector store and the pickled document store
# (the "chorma" spelling matches the artifact names on disk)
directory_path = "ohw_proj_chorma_db"
file_path = "ohw_proj_chorma_db.pcl"
# Helpers for loading the persisted retriever state
def load_from_pickle(filename):
    """Load a pickled object (here, the {doc_id: Document} store) from disk."""
    with open(filename, "rb") as file:
        return pickle.load(file)
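# Counterpart sketch (assumption, not in the original app): the .pcl docstore
# file is presumably produced offline by pickling the {doc_id: Document}
# mapping, along these lines.
def save_to_pickle(store_dict, filename):
    """Persist a {doc_id: Document} mapping for load_from_pickle to restore."""
    with open(filename, "wb") as file:
        pickle.dump(store_dict, file)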
def load_retriever(docstore_path, chroma_path, embeddings, child_splitter, parent_splitter):
    """Loads the vector store and document store, initializing the retriever."""
    db3 = Chroma(collection_name="full_documents",  # must match the collection name used at indexing time
                 embedding_function=embeddings,
                 persist_directory=chroma_path
                 )
    store_dict = load_from_pickle(docstore_path)
    store = InMemoryStore()
    store.mset(list(store_dict.items()))
    retriever = ParentDocumentRetriever(
        vectorstore=db3,
        docstore=store,
        child_splitter=child_splitter,
        parent_splitter=parent_splitter,
        search_kwargs={"k": 5}
    )
    return retriever
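# How the retriever works: small child chunks are embedded in Chroma for
# precise similarity search, and each hit is mapped back to its larger parent
# chunk in the docstore, so the LLM receives more surrounding context.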
def inspect(state):
    """Passthrough tap on the chain: records the retrieved documents'
    sources and contents in session state so the UI can display them."""
    if "context_sources" not in st.session_state:
        st.session_state.context_sources = []
    context = state['normal_context']
    st.session_state.context_sources = [doc.metadata['source'] for doc in context]
    st.session_state.context_content = [doc.page_content for doc in context]
    return state
def retrieve_normal_context(retriever, question):
    """Fetch the documents most relevant to the question."""
    docs = retriever.invoke(question)
    return docs
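# get_chain is cached with st.cache_resource, so the embeddings, retriever,
# and LLM client are rebuilt only when temperature or the model changes.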
@st.cache_resource
def get_chain(temperature, selected_model):
    embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L12-v2")
    docstore_path = file_path
    chroma_path = directory_path
    # Parent documents: the large chunks handed to the LLM
    parent_splitter = RecursiveCharacterTextSplitter(chunk_size=2000,
                                                     chunk_overlap=500)
    # Child documents: the small chunks embedded for similarity search
    child_splitter = RecursiveCharacterTextSplitter(chunk_size=300,
                                                    chunk_overlap=50)
    retriever = load_retriever(docstore_path, chroma_path, embeddings, child_splitter, parent_splitter)
    llm_api = os.getenv("blablador_api")
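    # Blablador exposes an OpenAI-compatible endpoint, so ChatOpenAI works here
    # by overriding the base URL; the API key is read from the environment.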
llm = ChatOpenAI(model_name=selected_model,
temperature=temperature,
openai_api_key=llm_api,
openai_api_base=blablador_base,
streaming=True)
    # Response prompt
    response_prompt_template = """You are an assistant who helps the Ocean Hack Week community answer their questions. I am going to ask you a question. Your response should be comprehensive and must not contradict the following context when it is relevant; if the context is not relevant, ignore it.
    Chat history so far: {chat_history}
Today's date: {date}
## Normal Context:
{normal_context}
# Original Question: {question}
# Answer:
"""
response_prompt = ChatPromptTemplate.from_template(response_prompt_template)
    context_chain = RunnableLambda(lambda x: {
        "question": x["question"],
        "normal_context": retrieve_normal_context(retriever, x["question"]),
        "chat_history": x["chat_history"],
        # evaluated per request: the chain object is cached, so binding
        # date.today() at build time would freeze the date
        "date": date.today()})
chain = (
context_chain
| RunnableLambda(inspect)
| response_prompt
| llm
| StrOutputParser()
)
return chain
def clear_chat_history():
st.session_state.messages = []
st.session_state.context_sources = []
st.session_state.key = 0
# Sidebar
with st.sidebar:
st.image("logo_no_bg.png", use_column_width=True) # Replace with your actual file path
st.title("OHW Assistant")
    selected_model = st.selectbox('Choose a LLM model',
                                  llm_list,
                                  key='selected_model',
                                  index=1)
temperature = st.slider("Temperature: ", 0.0, 1.0, 0.5, 0.1,
help=("Controls the creativity of responses.\n"
"Lower values make answers more focused.\n"
"Higher values introduce more variety."))
    # Map the display names to Blablador's model aliases
    model_aliases = {'Mistral-7B-Instruct-v0.2': 'alias-fast',
                     'Mixtral-8x7B-Instruct-v0.1': 'alias-large'}
    selected_model = model_aliases.get(selected_model, selected_model)
    chain = get_chain(temperature, selected_model)
st.button('Clear Chat History', on_click=clear_chat_history)
# Main app
# Initialize session state variables
if "messages" not in st.session_state:
st.session_state.messages = []
if "context_sources" not in st.session_state:
st.session_state.context_sources = []
if "context_content" not in st.session_state:
st.session_state.context_content = []
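# Replay the stored conversation; assistant turns also rebuild their Sources tab.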
for message in st.session_state.messages:
    if message["role"] == 'assistant':
with st.chat_message(message["role"]):
tab1, tab2 = st.tabs(["Answer", "Sources"])
with tab1:
st.markdown(message["content"])
with tab2:
for i, source in enumerate(message["sources"]):
name = f'{source}'
with st.expander(name):
st.markdown(f'{message["context"][i]}')
    else:
        with st.chat_message(message["role"]):
            st.markdown(message["content"])
if prompt := st.chat_input("How may I assist you today?"):
st.session_state.messages.append({"role": "user", "content": prompt})
with st.chat_message("user"):
st.markdown(prompt)
with st.chat_message("assistant"):
        query = st.session_state.messages[-1]['content']
tab1, tab2 = st.tabs(["Answer", "Sources"])
with tab1:
start_time = time.time()
placeholder = st.empty() # Create a placeholder in Streamlit
full_answer = ""
            for chunk in chain.stream({"question": query, "chat_history": st.session_state.messages}):
                full_answer += chunk
                placeholder.markdown(full_answer, unsafe_allow_html=True)
end_time = time.time()
st.caption(f"Response time: {end_time - start_time:.2f} seconds")
with tab2:
if st.session_state.context_sources:
                for i, source in enumerate(st.session_state.context_sources):
                    with st.expander(f'{source}'):
                        st.markdown(st.session_state.context_content[i])
else:
st.write("No sources available for this query.")
st.session_state.messages.append({"role": "assistant", "content": full_answer})
st.session_state.messages[-1]['sources'] = st.session_state.context_sources
st.session_state.messages[-1]['context'] = st.session_state.context_content