import streamlit as st
import os
import pickle
import time
from datetime import date

from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_huggingface import HuggingFaceEmbeddings
from langchain.retrievers import ParentDocumentRetriever
from langchain.storage import InMemoryStore
from langchain_chroma import Chroma
from langchain_openai import ChatOpenAI
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnableLambda

# Models available through Blablador, and the API endpoint.
llm_list = ['Mistral-7B-Instruct-v0.2', 'Mixtral-8x7B-Instruct-v0.1']
blablador_base = "https://helmholtz-blablador.fz-juelich.de:8000/v1"

# Persisted retriever artifacts: the Chroma directory and the pickled docstore.
directory_path = "ohw_proj_chorma_db"
file_path = "ohw_proj_chorma_db.pcl"


def load_from_pickle(filename):
    """Load the pickled parent-document store (a dict of id -> Document)."""
    with open(filename, "rb") as file:
        return pickle.load(file)


def load_retriever(docstore_path, chroma_path, embeddings, child_splitter, parent_splitter):
    """Load the vector store and document store, and initialize the retriever."""
    db3 = Chroma(
        collection_name="full_documents",  # must match the collection name used at indexing time
        embedding_function=embeddings,
        persist_directory=chroma_path,
    )
    store_dict = load_from_pickle(docstore_path)
    store = InMemoryStore()
    store.mset(list(store_dict.items()))
    retriever = ParentDocumentRetriever(
        vectorstore=db3,
        docstore=store,
        child_splitter=child_splitter,
        parent_splitter=parent_splitter,
        search_kwargs={"k": 5},
    )
    return retriever


def inspect(state):
    """Record the retrieved documents' sources and contents in the session state."""
    if "context_sources" not in st.session_state:
        st.session_state.context_sources = []
    context = state['normal_context']
    st.session_state.context_sources = [doc.metadata['source'] for doc in context]
    st.session_state.context_content = [doc.page_content for doc in context]
    return state


def retrieve_normal_context(retriever, question):
    """Retrieve the parent documents relevant to the question."""
    return retriever.invoke(question)


@st.cache_resource
def get_chain(temperature, selected_model):
    """Build the RAG chain: retrieval -> prompt -> LLM -> string output."""
    embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L12-v2")
    # Parent chunks are the large pieces the LLM sees; child chunks are the small pieces that get embedded.
    parent_splitter = RecursiveCharacterTextSplitter(chunk_size=2000, chunk_overlap=500)
    child_splitter = RecursiveCharacterTextSplitter(chunk_size=300, chunk_overlap=50)
    retriever = load_retriever(file_path, directory_path, embeddings, child_splitter, parent_splitter)

    llm_api = os.getenv("blablador_api")
    llm = ChatOpenAI(
        model_name=selected_model,
        temperature=temperature,
        openai_api_key=llm_api,
        openai_api_base=blablador_base,
        streaming=True,
    )
    today = date.today()

    # Response prompt
    response_prompt_template = """You are an assistant who helps the Ocean Hack Week community answer their questions.
I am going to ask you a question. Your response should be comprehensive and should not contradict the following context when it is relevant. Ignore the context when it is not relevant.

Chat history: {chat_history}
Today's date: {date}

## Normal Context: {normal_context}

# Original Question: {question}

# Answer: """
    response_prompt = ChatPromptTemplate.from_template(response_prompt_template)

    context_chain = RunnableLambda(lambda x: {
        "question": x["question"],
        "normal_context": retrieve_normal_context(retriever, x["question"]),
        "chat_history": x["chat_history"],
        "date": today,
    })
    chain = (
        context_chain
        | RunnableLambda(inspect)
        | response_prompt
        | llm
        | StrOutputParser()
    )
    return chain
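
# --- Indexing-side sketch (not called by this app) --------------------------
# The app only *loads* `directory_path` and `file_path`; they must have been
# produced by an earlier indexing run. The helper below is a minimal sketch of
# what that run could look like: the function name and the `docs` argument are
# illustrative, and the splitter/embedding settings must match the ones used
# in get_chain above.
def build_docstore(docs, chroma_path=directory_path, docstore_path=file_path):
    """Index `docs` into Chroma and pickle the parent docstore for later loading."""
    embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L12-v2")
    parent_splitter = RecursiveCharacterTextSplitter(chunk_size=2000, chunk_overlap=500)
    child_splitter = RecursiveCharacterTextSplitter(chunk_size=300, chunk_overlap=50)
    db = Chroma(
        collection_name="full_documents",  # same collection name as load_retriever
        embedding_function=embeddings,
        persist_directory=chroma_path,
    )
    store = InMemoryStore()
    retriever = ParentDocumentRetriever(
        vectorstore=db,
        docstore=store,
        child_splitter=child_splitter,
        parent_splitter=parent_splitter,
    )
    retriever.add_documents(docs)
    # Persist the in-memory docstore as a plain dict so that
    # load_from_pickle + store.mset can rebuild it on the next app start.
    store_dict = {key: store.mget([key])[0] for key in store.yield_keys()}
    with open(docstore_path, "wb") as f:
        pickle.dump(store_dict, f)
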
def clear_chat_history():
    st.session_state.messages = []
    st.session_state.context_sources = []
    st.session_state.key = 0


# Sidebar
with st.sidebar:
    st.image("logo_no_bg.png", use_column_width=True)  # Replace with your actual file path
    st.title("OHW Assistant")
    selected_model = st.selectbox('Choose a LLM model', llm_list, key='selected_model', index=1)
    temperature = st.slider(
        "Temperature: ", 0.0, 1.0, 0.5, 0.1,
        help=("Controls the creativity of responses.\n"
              "Lower values make answers more focused.\n"
              "Higher values introduce more variety."),
    )
    # Map the display names to the Blablador model aliases.
    model_aliases = {
        'Mistral-7B-Instruct-v0.2': 'alias-fast',
        'Mixtral-8x7B-Instruct-v0.1': 'alias-large',
    }
    selected_model = model_aliases.get(selected_model, selected_model)
    chain = get_chain(temperature, selected_model)
    st.button('Clear Chat History', on_click=clear_chat_history)

# Main app
# Initialize session state variables
if "messages" not in st.session_state:
    st.session_state.messages = []
if "context_sources" not in st.session_state:
    st.session_state.context_sources = []
if "context_content" not in st.session_state:
    st.session_state.context_content = []

# Replay the conversation so far; assistant turns also carry their sources.
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        if message["role"] == 'assistant':
            tab1, tab2 = st.tabs(["Answer", "Sources"])
            with tab1:
                st.markdown(message["content"])
            with tab2:
                for i, source in enumerate(message["sources"]):
                    with st.expander(f'{source}'):
                        st.markdown(f'{message["context"][i]}')
        else:
            st.markdown(message["content"])

if prompt := st.chat_input("How may I assist you today?"):
    st.session_state.messages.append({"role": "user", "content": prompt})
    with st.chat_message("user"):
        st.markdown(prompt)
    with st.chat_message("assistant"):
        query = st.session_state.messages[-1]['content']
        tab1, tab2 = st.tabs(["Answer", "Sources"])
        with tab1:
            start_time = time.time()
            placeholder = st.empty()  # placeholder that is updated as the answer streams in
            full_answer = ""
            for chunk in chain.stream({"question": query,
                                       "chat_history": st.session_state.messages}):
                full_answer += chunk
                placeholder.markdown(full_answer, unsafe_allow_html=True)
            end_time = time.time()
            st.caption(f"Response time: {end_time - start_time:.2f} seconds")
        with tab2:
            if st.session_state.context_sources:
                for i, source in enumerate(st.session_state.context_sources):
                    with st.expander(f'{source}'):
                        st.markdown(f'{st.session_state.context_content[i]}')
            else:
                st.write("No sources available for this query.")
    st.session_state.messages.append({"role": "assistant",
                                      "content": full_answer,
                                      "sources": st.session_state.context_sources,
                                      "context": st.session_state.context_content})
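
# ---------------------------------------------------------------------------
# Running the app (a usage sketch; the filename `app.py` is an assumption):
#
#   export blablador_api="<your Blablador API key>"   # read via os.getenv in get_chain
#   streamlit run app.py
#
# Both retriever artifacts ("ohw_proj_chorma_db/" and "ohw_proj_chorma_db.pcl")
# must exist in the working directory, e.g. produced by build_docstore above.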