Spaces:
Sleeping
Sleeping
import streamlit as st | |
import os | |
import shutil | |
import schedule | |
import time | |
import pickle | |
from langchain.prompts import ChatPromptTemplate | |
from langchain.text_splitter import RecursiveCharacterTextSplitter | |
from langchain_huggingface import HuggingFaceEmbeddings | |
from langchain_community.llms import HuggingFacePipeline | |
from langchain.retrievers import ParentDocumentRetriever | |
from langchain.storage import InMemoryStore | |
from langchain_chroma import Chroma | |
from langchain_openai import ChatOpenAI | |
from langchain_core.prompts import ChatPromptTemplate, FewShotChatMessagePromptTemplate | |
from langchain_core.output_parsers import StrOutputParser | |
from langchain_core.runnables import RunnableLambda | |
from datetime import date | |
import time | |
import subprocess | |
import threading | |
# Chat models offered in the sidebar, by display name.
llm_list = [
    'Mistral-7B-Instruct-v0.2',
    'Mixtral-8x7B-Instruct-v0.1',
]
# Base URL passed to ChatOpenAI as ``openai_api_base`` (OpenAI-compatible API).
blablador_base = "https://helmholtz-blablador.fz-juelich.de:8000/v1"

# On-disk retriever artifacts (plain paths, not environment variables):
# the Chroma persistence directory and the pickled document store.
directory_path = "ohw_proj_chorma_db"
file_path = "ohw_proj_chorma_db.pcl"
# Helpers that reload the persisted retriever (pickled docstore + Chroma index)
def load_from_pickle(filename):
    """Deserialize and return the object stored in the pickle file *filename*.

    Security note: unpickling can execute arbitrary code, so this must only be
    pointed at files produced by this project, never at uploaded or otherwise
    untrusted data.
    """
    with open(filename, "rb") as handle:
        loaded = pickle.load(handle)
    return loaded
def load_retriever(docstore_path, chroma_path, embeddings, child_splitter, parent_splitter):
    """Rebuild a ParentDocumentRetriever from its persisted pieces.

    Args:
        docstore_path: pickle file holding a mapping that is loaded into an
            ``InMemoryStore`` (presumably docstore key -> parent document;
            confirm against the indexing script).
        chroma_path: persistence directory of the Chroma collection
            ``"full_documents"``.
        embeddings: embedding function used to query the Chroma collection.
        child_splitter: splitter for the small, embedded chunks.
        parent_splitter: splitter for the larger chunks returned to callers.

    Returns:
        A ``ParentDocumentRetriever`` whose vector search uses ``k=5``.
    """
    vectorstore = Chroma(
        # Must be the same collection name that was used when the index was built.
        collection_name="full_documents",
        embedding_function=embeddings,
        persist_directory=chroma_path,
    )
    stored_items = load_from_pickle(docstore_path)
    docstore = InMemoryStore()
    docstore.mset(list(stored_items.items()))
    return ParentDocumentRetriever(
        vectorstore=vectorstore,
        docstore=docstore,
        child_splitter=child_splitter,
        parent_splitter=parent_splitter,
        search_kwargs={"k": 5},
    )
def inspect(state):
    """Record the retrieved documents in session state and pass *state* on.

    Runs as a pass-through step of the chain, right after context retrieval,
    so the UI can list the sources once the answer has streamed. It writes two
    parallel lists: ``context_sources`` (each document's ``source`` metadata)
    and ``context_content`` (each document's text).

    Note: the name shadows the stdlib ``inspect`` module; it is kept because
    the chain is wired to this name.

    Args:
        state: dict from the context step; ``state['normal_context']`` is the
            list of retrieved documents.

    Returns:
        *state*, unchanged.
    """
    documents = state['normal_context']
    # A document without a 'source' entry used to raise KeyError and abort
    # the whole streamed answer; fall back to a placeholder label instead.
    st.session_state.context_sources = [
        doc.metadata.get('source', 'Unknown source') for doc in documents
    ]
    st.session_state.context_content = [doc.page_content for doc in documents]
    return state
def retrieve_normal_context(retriever, question):
    """Return whatever ``retriever.invoke`` yields for *question*.

    Kept as a named step so the context-gathering code reads clearly and the
    retrieval call can be swapped in one place.
    """
    return retriever.invoke(question)
# Build the retrieval-augmented chat chain for the selected Blablador model
@st.cache_resource
def _load_default_retriever():
    """Load the embedding model and the persisted OHW retriever once.

    Streamlit re-executes the script (and therefore ``get_chain``) on every
    widget interaction; ``st.cache_resource`` keeps this retriever across
    reruns so the sentence-transformer model, the Chroma index and the pickled
    docstore are not reloaded each time.
    """
    embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L12-v2")
    # Larger parent chunks and small child chunks for ParentDocumentRetriever.
    # NOTE(review): sizes presumably must match how the index was built — confirm.
    parent_splitter = RecursiveCharacterTextSplitter(chunk_size=2000, chunk_overlap=500)
    child_splitter = RecursiveCharacterTextSplitter(chunk_size=300, chunk_overlap=50)
    # Module-level locations of the pickled docstore and the Chroma directory.
    return load_retriever(file_path, directory_path, embeddings, child_splitter, parent_splitter)


def _format_chat_history(messages):
    """Render chat messages as ``role: content`` lines for the prompt.

    Only ``role`` and ``content`` are kept. The app also stores each answer's
    retrieved ``sources``/``context`` on its message, and echoing those into
    every prompt would grow it by whole documents per turn. A string is
    assumed to be pre-formatted and is returned as-is.
    """
    if isinstance(messages, str):
        return messages
    return "\n".join(
        f"{message.get('role', 'unknown')}: {message.get('content', '')}"
        for message in messages or []
    )


def get_chain(temperature, selected_model):
    """Assemble the retrieval-augmented chat chain.

    Args:
        temperature: sampling temperature forwarded to the chat model.
        selected_model: model name/alias sent to the Blablador endpoint.

    Returns:
        A runnable taking ``{"question": str, "chat_history": ...}`` that
        streams the answer as strings. As a side effect of its ``inspect``
        step, the retrieved sources are written to ``st.session_state``.
    """
    retriever = _load_default_retriever()
    # The key comes from the environment only; never commit it to source.
    llm_api = os.getenv("blablador_api")
    llm = ChatOpenAI(model_name=selected_model,
                     temperature=temperature,
                     openai_api_key=llm_api,
                     openai_api_base=blablador_base,
                     streaming=True)
    # Response prompt
    response_prompt_template = """You are an assistant who helps Ocean Hack Week community to answer their questions. I am going to ask you a question. Your response should be comprehensive and not contradicted with the following context if they are relevant. Otherwise, ignore them if they are not relevant.
Keep track of chat history: {chat_history}
Today's date: {date}
## Normal Context:
{normal_context}
# Original Question: {question}
# Answer:
"""
    response_prompt = ChatPromptTemplate.from_template(response_prompt_template)

    def _gather_context(inputs):
        """Fetch documents for the question and fill in the prompt variables."""
        question = inputs["question"]
        return {
            "question": question,
            "normal_context": retrieve_normal_context(retriever, question),
            "chat_history": _format_chat_history(inputs["chat_history"]),
            # Evaluated per request so a long-lived chain never reports a stale date.
            "date": date.today(),
        }

    chain = (
        RunnableLambda(_gather_context)
        | RunnableLambda(inspect)
        | response_prompt
        | llm
        | StrOutputParser()
    )
    return chain
def clear_chat_history():
    """Reset the conversation and the most recent retrieval results.

    Wired as the ``on_click`` callback of the "Clear Chat History" button.
    ``context_sources`` and ``context_content`` are parallel lists, so both
    are cleared together to keep them the same length.
    """
    st.session_state.messages = []
    st.session_state.context_sources = []
    st.session_state.context_content = []
    # NOTE(review): 'key' is not read anywhere in this file; kept in case
    # other code depends on it.
    st.session_state.key = 0
# Sidebar: branding, model/temperature controls, chain construction and reset button.
with st.sidebar:
    st.image("logo_no_bg.png", use_column_width=True)  # Replace with your actual file path
    st.title("OHW Assistant")
    selected_model = st.sidebar.selectbox(
        'Choose a LLM model',
        llm_list,
        key='selected_model',
        index=1,
    )
    temperature = st.slider(
        "Temperature: ", 0.0, 1.0, 0.5, 0.1,
        help=("Controls the creativity of responses.\n"
              "Lower values make answers more focused.\n"
              "Higher values introduce more variety."),
    )
    # Display names are translated to the alias sent to the Blablador API;
    # any other value passes through unchanged.
    model_aliases = {
        'Mistral-7B-Instruct-v0.2': 'alias-fast',
        'Mixtral-8x7B-Instruct-v0.1': 'alias-large',
    }
    selected_model = model_aliases.get(selected_model, selected_model)
    chain = get_chain(temperature, selected_model)
    st.button('Clear Chat History', on_click=clear_chat_history)
# Main app
# Create every per-session list the app reads, on the first run only.
for _state_key in ("messages", "context_sources", "context_content"):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = []
# Replay the stored conversation on every rerun. Assistant turns get an
# "Answer" tab plus a "Sources" tab with one expander per retrieved document.
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        if message["role"] == 'assistant':
            answer_tab, sources_tab = st.tabs(["Answer", "Sources"])
            with answer_tab:
                st.markdown(message["content"])
            with sources_tab:
                sources = message.get("sources") or []
                contexts = message.get("context") or []
                if not sources:
                    # Same wording as the live view of a fresh answer.
                    st.write("No sources available for this query.")
                # zip pairs each source with its own text and cannot index
                # past the end if the two lists ever differ in length.
                for source, content in zip(sources, contexts):
                    with st.expander(f'{source}'):
                        st.markdown(f'{content}')
        else:
            st.markdown(message["content"])
# Handle a new question: echo it, stream the answer, show its sources, then
# store the turn so the replay loop can redraw it on later reruns.
if prompt := st.chat_input("How may I assist you today?"):
    st.session_state.messages.append({"role": "user", "content": prompt})
    with st.chat_message("user"):
        st.markdown(prompt)
    with st.chat_message("assistant"):
        answer_tab, sources_tab = st.tabs(["Answer", "Sources"])
        with answer_tab:
            start_time = time.perf_counter()  # monotonic clock for durations
            placeholder = st.empty()  # re-rendered with the growing answer
            full_answer = ""
            for chunk in chain.stream({"question": prompt,
                                       "chat_history": st.session_state.messages}):
                full_answer += chunk
                # Plain Markdown only: the answer is model output shaped by user
                # input and retrieved text, so raw HTML must not be enabled.
                # This also matches how stored answers are replayed.
                placeholder.markdown(full_answer)
            st.caption(f"Response time: {time.perf_counter() - start_time:.2f} seconds")
        with sources_tab:
            if st.session_state.context_sources:
                for source, content in zip(st.session_state.context_sources,
                                           st.session_state.context_content):
                    with st.expander(f'{source}'):
                        st.markdown(f'{content}')
            else:
                st.write("No sources available for this query.")
    st.session_state.messages.append({
        "role": "assistant",
        "content": full_answer,
        # Copies, so later changes to the session lists cannot alter this turn.
        "sources": list(st.session_state.context_sources),
        "context": list(st.session_state.context_content),
    })