from openai import OpenAI import streamlit as st import streamlit.components.v1 as components import datetime, time from dataclasses import dataclass import math import base64 ## Firestore ?? import os # import sys # import inspect # currentdir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe()))) # parentdir = os.path.dirname(currentdir) # sys.path.append(parentdir) import openai from langchain_openai import ChatOpenAI, OpenAI, OpenAIEmbeddings import tiktoken from langchain.prompts.few_shot import FewShotPromptTemplate from langchain.prompts.prompt import PromptTemplate from operator import itemgetter from langchain.schema import StrOutputParser from langchain_core.output_parsers import StrOutputParser from langchain_core.runnables import RunnablePassthrough, RunnableLambda, RunnableParallel from langchain_core.runnables import chain import langchain_community.embeddings.huggingface from langchain_community.embeddings.huggingface import HuggingFaceBgeEmbeddings, HuggingFaceEmbeddings from langchain_community.vectorstores import FAISS from langchain.chains import LLMChain from langchain.chains.conversation.memory import ConversationBufferWindowMemory #, ConversationBufferMemory, ConversationSummaryMemory, ConversationSummaryBufferMemory import os, dotenv from dotenv import load_dotenv load_dotenv() import firebase_admin, json from firebase_admin import credentials, storage, firestore import plotly.express as px import plotly.graph_objects as go import pandas as pd import networkx as nx if not os.path.isdir("./.streamlit"): os.mkdir("./.streamlit") print('made streamlit folder') if not os.path.isfile("./.streamlit/secrets.toml"): with open("./.streamlit/secrets.toml", "w") as f: f.write(os.environ.get("STREAMLIT_SECRETS")) print('made new file') import os, dotenv from dotenv import load_dotenv load_dotenv() if not os.path.isdir("./.streamlit"): os.mkdir("./.streamlit") print('made streamlit folder') if not os.path.isfile("./.streamlit/secrets.toml"): with open("./.streamlit/secrets.toml", "w") as f: f.write(os.environ.get("STREAMLIT_SECRETS")) print('made new file') import db_firestore as db ## Load from streamlit!! os.environ["HF_TOKEN"] = os.environ.get("HF_TOKEN") or st.secrets["HF_TOKEN"] os.environ["OPENAI_API_KEY"] = os.environ.get("OPENAI_API_KEY") or st.secrets["OPENAI_API_KEY"] os.environ["FIREBASE_CREDENTIAL"] = os.environ.get("FIREBASE_CREDENTIAL") or st.secrets["FIREBASE_CREDENTIAL"] if "openai_model" not in st.session_state: st.session_state["openai_model"] = "gpt-3.5-turbo-1106" ## Hardcode indexes for now ## TODO: Move indexes to firebase indexes = """Bleeding ChestPain Dysphagia Headache ShortnessOfBreath Vomiting Weakness Weakness2""".split("\n") model_name = "BAAI/bge-large-en-v1.5" model_kwargs = {"device": "cpu"} encode_kwargs = {"normalize_embeddings": True} if "embeddings" not in st.session_state: st.session_state.embeddings = HuggingFaceBgeEmbeddings( model_name=model_name, model_kwargs = model_kwargs, encode_kwargs = encode_kwargs) embeddings = st.session_state.embeddings if "llm" not in st.session_state: st.session_state.llm = ChatOpenAI(model_name="gpt-3.5-turbo-1106", temperature=0) llm = st.session_state.llm if "llm_i" not in st.session_state: st.session_state.llm_i = OpenAI(model_name="gpt-3.5-turbo-instruct", temperature=0) llm_i = st.session_state.llm_i if "llm_gpt4" not in st.session_state: st.session_state.llm_gpt4 = ChatOpenAI(model_name="gpt-4-1106-preview", temperature=0) llm_gpt4 = st.session_state.llm_gpt4 if "TEMPLATE" not in st.session_state: with open('templates/patient.txt', 'r') as file: TEMPLATE = file.read() st.session_state.TEMPLATE = TEMPLATE TEMPLATE = st.session_state.TEMPLATE prompt = PromptTemplate( input_variables = ["question", "context"], template = st.session_state.TEMPLATE ) def format_docs(docs): return "\n--------------------\n".join(doc.page_content for doc in docs) sp_mapper = {"human":"student","ai":"patient", "user":"student","assistant":"patient"} if "TEMPLATE2" not in st.session_state: with open('templates/grader.txt', 'r') as file: TEMPLATE2 = file.read() st.session_state.TEMPLATE2 = TEMPLATE2 TEMPLATE2 = st.session_state.TEMPLATE2 prompt2 = PromptTemplate( input_variables = ["question", "context", "history"], template = st.session_state.TEMPLATE2 ) @chain def get_patient_chat_history(_): return st.session_state.get("patient_chat_history") if not st.session_state.get("scenario_list", None): st.session_state.scenario_list = indexes def init_patient_llm(): index_name = f"indexes/{st.session_state.scenario_list[st.session_state.selected_scenario]}/QA" if "store" not in st.session_state: st.session_state.store = db.get_store(index_name, embeddings=embeddings) if "retriever" not in st.session_state: st.session_state.retriever = st.session_state.store.as_retriever(search_type="similarity", search_kwargs={"k":2}) if "memory" not in st.session_state: st.session_state.memory = ConversationBufferWindowMemory( llm=llm, memory_key="chat_history", input_key="question", k=5, human_prefix="student", ai_prefix="patient",) if ("chain" not in st.session_state or st.session_state.TEMPLATE != TEMPLATE): st.session_state.chain = ( RunnableParallel({ "context": st.session_state.retriever | format_docs, "question": RunnablePassthrough() }) | LLMChain(llm=llm, prompt=prompt, memory=st.session_state.memory, verbose=False) ) # def init_grader_llm(): login_info = { "bob":"builder", "student1": "password", "admin":"admin" } def set_username(x): st.session_state.username = x def validate_username(username, password): if login_info.get(username) == password: set_username(username) else: st.warning("Wrong username or password") return None if not st.session_state.get("username"): ## ask to login st.title("Login") username = st.text_input("Username:") password = st.text_input("Password:", type="password") login_button = st.button("Login", on_click=validate_username, args=[username, password]) ll, rr = st.columns(2) ## TODO: Sync login info usernames to firebase, and remove this portion ll.header("Admin Login") ll.write("Username: admin") ll.write("Password: admin") rr.header("Student Login") rr.write("Username: student1") rr.write("Password: password") else: if True: ## Says hello and logout col_1, col_2 = st.columns([1,3]) col_2.title(f"Hello there, {st.session_state.username}") # Display logout button if col_1.button('Logout'): # Remove username from session state del st.session_state.username # Rerun the app to go back to the login view st.rerun() scenario_tab, dashboard_tab, generate_tab = st.tabs(["Training", "Dashboard", "Generate Scenario"]) class ScenarioTabIndex: SELECT_SCENARIO = 0 PATIENT_LLM = 1 GRADER_LLM = 2 def set_scenario_tab_index(x): st.session_state.scenario_tab_index=x return None def go_to_patient_llm(): selected_scenario = st.session_state.get('selected_scenario') if selected_scenario is None or selected_scenario < 0: st.warning("Please select a scenario!") else: st.session_state.start_time = datetime.datetime.utcnow() states = ["store", "store2","retriever","retriever2","chain","chain2"] for state_to_del in states: if state_to_del in st.session_state: del st.session_state[state_to_del] init_patient_llm() set_scenario_tab_index(ScenarioTabIndex.PATIENT_LLM) if not st.session_state.get("scenario_tab_index"): set_scenario_tab_index(ScenarioTabIndex.SELECT_SCENARIO) with scenario_tab: ## if True: ## Check in select scenario if st.session_state.scenario_tab_index == ScenarioTabIndex.SELECT_SCENARIO: def change_scenario(scenario_index): st.session_state.selected_scenario = scenario_index if st.session_state.get("selected_scenario", None) is None: st.session_state.selected_scenario = -1 total_cols = 3 rows = list() # for _ in range(0, number_of_indexes, total_cols): # rows.extend(st.columns(total_cols)) st.header(f"Selected Scenario: {st.session_state.scenario_list[st.session_state.selected_scenario] if st.session_state.selected_scenario>=0 else 'None'}") #st.button("Generate a new scenario") for i, scenario in enumerate(st.session_state.scenario_list): if i % total_cols == 0: rows.extend(st.columns(total_cols)) curr_col = rows[(-total_cols + i % total_cols)] tile = curr_col.container(height=120) ## TODO: Implement highlight box if index is selected # if st.session_state.selected_scenario == i: # tile.markdown("", unsafe_allow_html=True) tile.write(":balloon:") tile.button(label=scenario, on_click=change_scenario, args=[i]) select_scenario_btn = st.button("Select Scenario", on_click=go_to_patient_llm, args=[]) elif st.session_state.scenario_tab_index == ScenarioTabIndex.PATIENT_LLM: st.header("Patient info") ## TODO: Put the patient's info here, from SCENARIO # st.write("Pull the info here!!!") col1, col2, col3 = st.columns([1,3,1]) with col1: back_to_scenario_btn = st.button("Back to selection", on_click=set_scenario_tab_index, args=[ScenarioTabIndex.SELECT_SCENARIO]) # with col3: # start_timer_button = st.button("START") with col2: TIME_LIMIT = 60*10 ## to change to 10 minutes time.sleep(1) # if start_timer_button: # st.session_state.start_time = datetime.datetime.now() # st.session_state.time = -1 if not st.session_state.get('time') else st.session_state.get('time') st.session_state.start_time = False if not st.session_state.get('start_time') else st.session_state.start_time from streamlit.components.v1 import html html(f"""