Spaces:
Paused
Paused
import re | |
import uuid | |
import pandas as pd | |
import streamlit as st | |
import re | |
import matplotlib.pyplot as plt | |
import subprocess | |
import sys | |
import io | |
from utils.default_values import get_system_prompt, get_guidelines_dict | |
from utils.epfl_meditron_utils import get_llm_response | |
from utils.openai_utils import get_available_engines, get_search_query_type_options | |
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay | |
from sklearn.metrics import classification_report | |
# Base folder prepended to the transcript path built by get_diarized_f_path.
DATA_FOLDER = "data/"
# Version string shown in the sidebar title.
POC_VERSION = "0.1.0"
# NOTE(review): MAX_QUESTIONS and AVAILABLE_LANGUAGES are not referenced in the
# code visible in this file section; confirm whether they are still used.
MAX_QUESTIONS = 10
AVAILABLE_LANGUAGES = ["DE", "EN", "FR"]
# Browser-tab title and icon. Keep this ahead of any other st.* call: Streamlit
# has required set_page_config to be the first Streamlit command (confirm for
# the installed version).
# NOTE(review): the tab title says "Whisper PoC" while the sidebar says
# "Local LLM PoC"; confirm which name is intended.
st.set_page_config(page_title='Medgate Whisper PoC', page_icon='public/medgate.png')
# Azure apparently truncates the system message if it is longer than 200
# (tokens, per the constant name). The reference link that originally followed
# "see" is missing. Not referenced in the code visible here.
MAX_SYSTEM_MESSAGE_TOKENS = 200
def format_question(q):
    """Clean up a proposed question for display.

    Strips a leading enumeration prefix (e.g. ``"1. "``, ``"12) "``) and
    replaces ``[docN]`` citation markers with ``[<title>]``, using the
    ``"title"`` of the N-th entry of ``st.session_state["citations"]``.

    :param q: raw question text.
    :return: the cleaned question, with surrounding whitespace stripped.
    """
    # Remove numerical prefixes, if any, e.g. '1. [...]'. The former check
    # (single digit, unescaped '.', fixed 3-character slice) mangled
    # multi-digit prefixes such as '10. '.
    res = re.sub(r'^\d+[.):]\s+', '', q)
    # Replace doc reference by doc name
    citations = st.session_state["citations"]
    if len(citations) > 0:
        for source_ref in re.findall(r'\[doc[0-9]+\]', res):
            citation_number = int(re.findall(r'[0-9]+', source_ref)[0])
            # [docN] is 1-based; [doc0] is clamped to the first citation.
            citation_index = citation_number - 1 if citation_number > 0 else 0
            try:
                citation = citations[citation_index]
            except (IndexError, KeyError):
                # Reference to a citation that was not returned: leave it as-is
                # instead of crashing the page.
                continue
            res = res.replace(source_ref, '[' + citation["title"] + ']')
    return res.strip()
def get_text_from_row(text):
    """Return *text* as a string; a value whose str() is "nan" becomes ""."""
    as_string = str(text)
    return "" if as_string == "nan" else as_string
def get_questions_from_df(df, lang, test_scenario_name):
    """Build one question dict per row of *df*.

    Each dict holds the question from the ``"<lang>: Fragen"`` column, the
    scenario answer from *test_scenario_name* (cleaned by get_text_from_row),
    and a fresh ``uuid4`` as ``question_id``.
    """
    question_column = lang + ": Fragen"
    return [
        {
            "question": row[question_column],
            "answer": get_text_from_row(row[test_scenario_name]),
            "question_id": uuid.uuid4(),
        }
        for _, row in df.iterrows()
    ]
def get_questions(df, lead_symptom, lang, test_scenario_name):
    """Return the questions for a lead symptom and test scenario.

    The list is cached in ``st.session_state["questions"]``. It is only rebuilt
    when the lead symptom or the scenario differs from the one stored by the
    previous call.

    :param df: questionnaire table with ``"<lang>: Symptome"`` and
        ``"<lang>: Fragen"`` columns plus one column per test scenario.
    :param lead_symptom: symptom value used to filter *df*.
    :param lang: language prefix of the columns, e.g. ``"DE"``.
    :param test_scenario_name: column holding the scenario's answers.
    :return: list of question dicts (see get_questions_from_df).
    """
    # .get(): these keys are only set once the sidebar form has been
    # submitted, so direct indexing raised KeyError before that.
    cached_symptom = st.session_state.get("lead_symptom")
    cached_scenario = st.session_state.get("scenario_name")
    # f-strings: plain concatenation raised TypeError for a None argument.
    print(f"{cached_symptom} -> {lead_symptom}")
    print(f"{cached_scenario} -> {test_scenario_name}")
    if cached_symptom != lead_symptom or cached_scenario != test_scenario_name:
        st.session_state["lead_symptom"] = lead_symptom
        st.session_state["scenario_name"] = test_scenario_name
        # "language" is never initialised in this file; fall back to the
        # caller's lang (the same column prefix get_prompt_from_lead_symptom
        # uses) instead of raising KeyError.
        symptom_lang = st.session_state.get("language", lang)
        symptom_col_name = symptom_lang + ": Symptome"
        df_questions = df[df[symptom_col_name] == lead_symptom]
        st.session_state["questions"] = get_questions_from_df(df_questions, lang, test_scenario_name)
    return st.session_state.get("questions", [])
def display_streamlit_sidebar():
    """Render the sidebar configuration form and start a session on submit.

    The form widgets are pre-filled from ``st.session_state``. Submitting
    while no session is running calls restart_session(), stores the chosen
    parameters in ``st.session_state`` and reruns the script.
    """
    sidebar = st.sidebar
    state = st.session_state
    sidebar.title("Local LLM PoC " + str(POC_VERSION))
    sidebar.write('**Parameters**')

    form = sidebar.form("config_form", clear_on_submit=True)
    model_option = form.selectbox("Quickly select a model", ("llama", "meditron"))
    # NOTE(review): the Repo field is pre-filled with the bare option name
    # ("llama"/"meditron") instead of a repository id like the default
    # state["model_repo_id"]; confirm this is intended.
    model_repo_id = form.text_input(label="Repo", value=model_option)
    model_filename = form.text_input(label="File name", value=state["model_filename"])
    model_type = form.text_input(label="Model type", value=state["model_type"])
    gpu_layers = form.slider('GPU Layers', min_value=0,
                             max_value=100, value=state['gpu_layers'], step=1)
    # The system-prompt text area is disabled; an empty prompt is stored on submit.
    system_prompt = ""
    temperature = form.slider('Temperature (0 = deterministic, 1 = more freedom)', min_value=0.0,
                              max_value=1.0, value=state['temperature'], step=0.1)
    top_p = form.slider('top_p (0 = focused, 1 = broader answer range)', min_value=0.0,
                        max_value=1.0, value=state['top_p'], step=0.1)
    form.write('Best practice is to only modify temperature or top_p, not both')

    # Guard clause: nothing to do unless the form was submitted for a new session.
    if not form.form_submit_button("Start session") or state['session_started']:
        return

    print('Parameters updated...')
    restart_session()
    submitted_values = {
        'session_started': True,
        "model_repo_id": model_repo_id,
        "model_filename": model_filename,
        "model_type": model_type,
        'gpu_layers': gpu_layers,
        "questions": [],
        "lead_symptom": None,
        "scenario_name": None,
        "system_prompt": system_prompt,
        "temperature": temperature,
        "top_p": top_p,
    }
    for key, value in submitted_values.items():
        state[key] = value
    st.rerun()
def to_str(text):
    """Return *text* prefixed with a space; a "nan" value becomes just " "."""
    as_string = str(text)
    suffix = "" if as_string == "nan" else as_string
    return " " + suffix
def set_df_prompts(path, sheet_name):
    """Load the guideline-prompt sheet into ``st.session_state["df_prompts"]``.

    The sheet is read without a header. Each sheet column becomes one output
    row:

    - row 0 -> "Questionnaire"
    - row 1 -> "Used Guideline"
    - rows 2..end, concatenated and space-separated -> "Prompt"

    The first sheet column is dropped (presumably a label column; confirm
    against the workbook).

    :param path: anything accepted by ``pd.read_excel``.
    :param sheet_name: sheet to read.
    """
    df_raw = pd.read_excel(path, sheet_name, header=None)
    # First prompt piece. Empty cells (NaN) become "" instead of crashing the
    # concatenation below (float + str raises TypeError).
    prompt = df_raw.iloc[2].apply(get_text_from_row)
    for i in range(3, df_raw.shape[0]):
        # to_str prefixes every further piece with a space (" " for empty cells).
        prompt = prompt + df_raw.iloc[i].apply(to_str)
    # Build the transposed table directly. This avoids writing back into
    # df_raw and then into a column-subset copy of its transpose.
    df_prompts = pd.DataFrame({
        "Questionnaire": df_raw.iloc[0].astype(str),
        "Used Guideline": df_raw.iloc[1].astype(str),
        "Prompt": prompt.astype(str),
    })
    st.session_state["df_prompts"] = df_prompts[1:]
def handle_nbq_click(c):
    """Copy a proposed next-best question into ``st.session_state['doctor_question']``.

    Every bracketed source marker (e.g. ``[doc1]`` or ``[Guideline title]``)
    is removed and the result is whitespace-stripped.

    :param c: question text, possibly containing bracketed source markers.
    """
    # Match one bracket pair at a time. The former greedy '\[.*\]' removed
    # everything from the first '[' to the last ']', including real question
    # text between two references.
    question_without_source = re.sub(r'\[[^\]]*\]', '', c).strip()
    st.session_state['doctor_question'] = question_without_source
def get_doctor_question_value():
    """Return the stored doctor question, or '' if none has been stored."""
    return st.session_state.get('doctor_question', '')
def update_chat_history(dr_question, patient_reply):
    """Append the doctor question and/or patient reply to the chat history.

    Each argument that is None is skipped. Messages are appended as
    ``{"role": ..., "content": ...}`` dicts to
    ``st.session_state["chat_history_array"]``, which is returned.
    """
    print("update_chat_history" + str(dr_question) + " - " + str(patient_reply) + '...\n')
    history = st.session_state["chat_history_array"]
    for role, content in (("Doctor", dr_question), ("Patient", patient_reply)):
        if content is None:
            continue
        history.append({"role": role, "content": content})
    return st.session_state["chat_history_array"]
def get_chat_history_string(chat_history):
    """Render a chat history as Markdown with a bold speaker label per message.

    :param chat_history: iterable of ``{"role": ..., "content": ...}`` dicts
        as built by update_chat_history. The role must be "Doctor" or "Patient".
    :return: the concatenated text. Doctor entries are followed by one
        newline and Patient entries by two (each padded with spaces).
    :raises ValueError: for any other role. This is a subclass of the
        ``Exception`` raised previously, so existing handlers still catch it.
    """
    parts = []
    for message in chat_history:
        role = message["role"]
        if role == "Doctor":
            prefix, suffix = '**Doctor**: ', " \n "
        elif role == "Patient":
            prefix, suffix = '**Patient**: ', " \n\n "
        else:
            raise ValueError('Unknown role: ' + str(role))
        # str() before strip(): content may be non-string (e.g. a number from a
        # spreadsheet). The former str(content.strip()) raised AttributeError.
        parts.append(prefix + str(message["content"]).strip() + suffix)
    # join instead of repeated += (linear instead of potentially quadratic).
    return ''.join(parts)
def restart_session():
    """Reset every per-session key in ``st.session_state`` to its default.

    Model and UI configuration written by init_session_state (model ids,
    temperature, ...) is not touched. get_guidelines_dict() is evaluated at
    the same point in the sequence as the individual assignments it follows.
    """
    print("Resetting params...")
    state = st.session_state
    triage_prompt_variants = [
        '''You are a telemedicine triage agent that decides between the following:
Emergency: Patient health is at risk if he doesn't speak to a Doctor urgently
Telecare: Patient can likely be treated remotely
General Practitioner: Patient should visit a GP for an ad-real consultation''',
        '''You are a Doctor assistant that decides if a medical case can likely be treated remotely by a Doctor or not.
The remote Doctor can write prescriptions and request the patient to provide a picture.
Provide the triage recommendation first, and then explain your reasoning respecting the format given below:
Treat remotely: <your reasoning>
Treat ad-real: <your reasoning>''',
        '''You are a medical triage agent working for the telemedicine Company Medgate based in Switzerland.
You decide if a case can be treated remotely or not, knowing that the remote Doctor can write prescriptions and request pictures.
Provide the triage recommendation first, and then explain your reasoning respecting the format given below:
Treat remotely: <your reasoning>
Treat ad-real: <your reasoning>''',
    ]
    leading_defaults = (
        ("emg_class_enabled", False),
        ("enable_llm_summary", False),
        ("num_variants", 3),
        ("lang_index", 0),
        ("llm_message", ""),
        ("llm_messages", []),
        ("triage_prompt_variants", triage_prompt_variants),
        ('nbqs', []),
        # NOTE(review): format_question indexes this by integer position;
        # confirm whether a list or an int-keyed dict is stored here later.
        ('citations', {}),
        ('past_messages', []),
        ("last_request", None),
        ("last_proposal", None),
        ('doctor_question', ''),
        ('patient_reply', ''),
        ('chat_history_array', []),
        ('chat_history', ''),
        ('feed_summary', ''),
        ('summary', ''),
        ("selected_guidelines", ["General"]),
    )
    for key, value in leading_defaults:
        state[key] = value
    state["guidelines_dict"] = get_guidelines_dict()
    state["triage_recommendation"] = ''
    state["session_events"] = []
def init_session_state():
    """Initialise ``st.session_state`` for a fresh, not-yet-started session.

    Writes the model/UI defaults, resets the per-session keys through
    restart_session(), then loads the system prompt.
    """
    print('init_session_state()')
    state = st.session_state
    model_defaults = (
        ('session_started', False),
        ('guidelines_ignored', False),
        ('model_index', 1),
        ("model_repo_id", "TheBloke/meditron-7B-GGUF"),
        ("model_filename", "meditron-7b.Q5_K_S.gguf"),
        ("model_type", "llama"),
        ('gpu_layers', 1),
    )
    for key, value in model_defaults:
        state[key] = value

    gender_index = 0
    state['gender'] = get_genders()[gender_index]
    state['gender_index'] = gender_index
    state['age'] = 30
    state['patient_medical_info'] = ''

    search_query_index = 0
    state['search_query_type'] = get_search_query_type_options()[search_query_index]
    state['search_query_type_index'] = search_query_index
    state['engine'] = get_available_engines()[0]

    sampling_and_flags = (
        ('temperature', 0.0),
        ('top_p', 1.0),
        ('feed_chat_transcript', ''),
        ("llm_model", True),
        ("hugging_face_models", True),
        ("local_models", True),
    )
    for key, value in sampling_and_flags:
        state[key] = value

    restart_session()
    state['system_prompt'] = get_system_prompt()
    state['system_prompt_after_on_change'] = get_system_prompt()
    state["summary"] = ''
def get_genders():
    """Return the selectable patient genders in display order (a new list per call)."""
    genders = ["Male", "Female"]
    return genders
def display_session_overview():
    """Show the LLM query history and the session's token/cost/time totals.

    Totals are shown only when the session events contain a
    ``prompt_tokens`` column. Each total is skipped if its own column is
    missing, instead of raising KeyError.
    """
    st.subheader('History of LLM queries')
    st.write(st.session_state["llm_messages"])
    st.subheader("Session costs overview")
    df_session_overview = pd.DataFrame.from_dict(st.session_state["session_events"])
    st.write(df_session_overview)
    if "prompt_tokens" not in df_session_overview:
        return
    # (column, label) pairs. The former code kept the total cost in a variable
    # named completion_cost; a table avoids such copy-paste slips.
    totals = (
        ("prompt_tokens", "Prompt tokens: "),
        ("prompt_cost_chf", "Prompt CHF: "),
        ("completion_tokens", "Completion tokens: "),
        ("completion_cost_chf", "Completion CHF: "),
        ("total_cost_chf", "Total costs CHF: "),
        ("response_time", "Total compute time (ms): "),
    )
    for column, label in totals:
        if column in df_session_overview:
            st.write(label + str(df_session_overview[column].sum()))
def remove_question(question_id):
    """Drop every question whose id matches *question_id* (compared as str), then rerun."""
    target_id = str(question_id)
    remaining = []
    for question in st.session_state["questions"]:
        if str(question["question_id"]) == target_id:
            continue
        remaining.append(question)
    st.session_state["questions"] = remaining
    st.rerun()
def get_prompt_from_lead_symptom(df_config, df_prompt, lead_symptom, lang, fallback=True):
    """Return the guideline prompt whose questionnaire mentions the lead symptom.

    Non-German symptoms are first translated via *df_config*
    (``"<lang>: Symptome"`` -> ``"DE: Symptome"``). If no translation exists,
    the symptom is used as-is. The first *df_prompt* row whose
    "Questionnaire" contains the German symptom wins.

    :param fallback: when no prompt matches, return
        ``st.session_state["system_prompt"]`` (True) or "" (False). A toast
        warning is shown in both cases.
    :return: the matching "Prompt" value, or the fallback described above.
    """
    de_lead_symptom = lead_symptom
    if lang != "DE":
        translations = df_config.loc[df_config[lang + ": Symptome"] == lead_symptom, "DE: Symptome"]
        # .iloc[0] on an empty selection raised IndexError; keep the
        # untranslated symptom and let the "no guidelines" path handle it.
        if len(translations) > 0:
            de_lead_symptom = translations.iloc[0]
        print(f"DE lead symptom: {de_lead_symptom}")
    for questionnaire, prompt in zip(df_prompt["Questionnaire"], df_prompt["Prompt"]):
        if de_lead_symptom in questionnaire:
            return prompt
    warning_text = f"No guidelines found for lead symptom {lead_symptom} ({de_lead_symptom})"
    if fallback:
        st.toast(warning_text + ", using generic prompt", icon='🚨')
        return st.session_state["system_prompt"]
    st.toast(warning_text, icon='🚨')
    return ""
def get_scenarios(df):
    """Return the test-scenario column labels of *df* (those starting with 'TLC' or 'GP').

    Labels are compared via str() so non-string headers (e.g. integer labels
    pandas assigns to unnamed spreadsheet columns) are skipped instead of
    raising AttributeError. This matches get_test_scenarios.
    """
    return [v for v in df.columns.values if str(v).startswith(('TLC', 'GP'))]
def get_gender_age_from_test_scenario(test_scenario):
    """Extract (gender, age) from a scenario name such as ``"TLC_F45"``.

    The first ``F<digits>`` or ``M<digits>`` occurrence is used. When none is
    found (or *test_scenario* is not a string), a Streamlit error is shown and
    ``("Male", 30)`` is returned.

    :return: tuple ``(gender, age)`` with gender ``"Male"`` or ``"Female"``.
    """
    gender_by_code = {"M": "Male", "F": "Female"}
    # Explicit checks replace the former bare `except:`, which also hid
    # unrelated errors (and even KeyboardInterrupt).
    match = re.search(r"([FM])(\d+)", test_scenario) if isinstance(test_scenario, str) else None
    if match is None:
        st.error("Unable to extract age, gender; using 30M as default")
        return "Male", 30
    return gender_by_code[match.group(1)], int(match.group(2))
def get_freetext_to_reco(reco_freetext_cased, emg_class_enabled=False):
    """Map an LLM's free-text triage answer to a structured recommendation.

    Matching is case-insensitive and checked in priority order:

    1. Explicit prefixes, such as the ``"Treat remotely: ..."`` /
       ``"Treat ad-real: ..."`` formats requested by the triage prompts.
    2. Keywords anywhere in the text. GP markers and negated remote-treatment
       phrases are checked before the positive telecare keywords, so
       "nicht telemedizinisch" yields 'GP' rather than matching "telemed".

    :param reco_freetext_cased: LLM answer. None, empty or non-string values
        are treated as an empty answer.
    :param emg_class_enabled: when True, an answer *starting* with an
        emergency keyword maps to 'EMERGENCY'; otherwise it maps to 'GP'.
    :return: 'TELECARE', 'GP', 'EMERGENCY' or 'TRIAGE_IMPOSSIBLE'. The last
        also shows a Streamlit toast and prints a warning.
    """
    reco_freetext = reco_freetext_cased.lower() if isinstance(reco_freetext_cased, str) else ""

    if reco_freetext.startswith(('treat remotely', 'telecare')):
        return 'TELECARE'
    if reco_freetext.startswith(('treat ad-real', 'gp', 'general practitioner')):
        return 'GP'
    if reco_freetext.startswith(('emergency', 'emg', 'urgent')):
        return 'EMERGENCY' if emg_class_enabled else 'GP'

    gp_markers = (
        "gp", 'general practitioner', "nicht über tele", 'durch einen arzt erford',
        "persönliche untersuchung erfordert",
        # Negations must win over the positive 'telemed' keyword below.
        'not be treated remotely', "nicht tele",
    )
    if any(marker in reco_freetext for marker in gp_markers):
        return 'GP'
    if any(marker in reco_freetext for marker in ("telecare", 'telemed', 'can be treated remotely')):
        return 'TELECARE'
    # Emergency keywords found mid-text map to GP regardless of emg_class_enabled.
    if any(marker in reco_freetext for marker in ('emergency', 'urgent')):
        return 'GP'

    warning_msg = 'Cannot extract reco from LLM text: ' + reco_freetext
    st.toast(warning_msg)
    print(warning_msg)
    return 'TRIAGE_IMPOSSIBLE'
def get_structured_reco(row, index, emg_class_enabled):
    """Convert the row's ``llm_reco_freetext_<index>`` answer into a structured reco.

    :param row: result row containing the free-text LLM answer.
    :param index: variant number appended to the column name.
    :param emg_class_enabled: forwarded to get_freetext_to_reco.
    :return: the recommendation returned by get_freetext_to_reco.
    """
    freetext_reco = row["llm_reco_freetext_" + str(index)]
    # Empty cells arrive as NaN floats, and .lower() on them raised
    # AttributeError. Lower-casing is done by get_freetext_to_reco itself.
    if not isinstance(freetext_reco, str):
        freetext_reco = ""
    return get_freetext_to_reco(freetext_reco, emg_class_enabled)
def add_expected_dispo(row, emg_class_enabled):
    """Return the expected disposition for a result row.

    'GP' and 'TELECARE' pass through. 'EMERGENCY' is kept only when
    *emg_class_enabled*, otherwise it maps to 'GP'.

    :raises ValueError: for any other disposition. This is a subclass of the
        ``Exception`` raised previously.
    """
    disposition = row["disposition"]
    if disposition in ("GP", "TELECARE"):
        return disposition
    if disposition == "EMERGENCY":
        return "EMERGENCY" if emg_class_enabled else "GP"
    # str() on case_summary: a NaN summary used to turn this message into an
    # unrelated TypeError from string concatenation.
    raise ValueError("Missing disposition for row " + str(row.name)
                     + " with summary " + str(row["case_summary"]))
def get_test_scenarios(df):
    """Return the column labels of *df* whose str() starts with 'GP' or 'TLC', in column order."""
    scenario_prefixes = ('GP', 'TLC')
    return [column for column in df.columns.values if str(column).startswith(scenario_prefixes)]
def get_transcript(df, test_scenario, lang):
    """Build a Doctor/Patient transcript from a questionnaire and one scenario.

    Each row contributes ``"\\nDoctor: <question>, Patient: <answer>"``.

    :param df: table with a ``"<lang>: Fragen"`` column and a *test_scenario* column.
    :param test_scenario: column holding the patient's answers.
    :param lang: language prefix of the question column, e.g. ``"DE"``.
    :return: the concatenated transcript ("" for an empty table).
    """
    question_column = lang + ": Fragen"
    lines = []
    for question, answer in zip(df[question_column], df[test_scenario]):
        # Missing answers become "" (as in get_questions_from_df) instead of
        # a literal "nan" in the text.
        answer_text = str(answer)
        if answer_text == "nan":
            answer_text = ""
        # str(question): non-string cells used to raise TypeError here.
        lines.append("\nDoctor: " + str(question) + ", Patient: " + answer_text)
    return "".join(lines)
def get_expected_from_scenario(test_scenario):
    """Map a scenario name's prefix (text before the first '_') to its expected reco.

    'GP' -> 'GP', 'TLC' -> 'TELECARE'; any other prefix raises Exception.
    """
    prefix, *_ = test_scenario.split('_')
    expected = {"GP": "GP", "TLC": "TELECARE"}.get(prefix)
    if expected is None:
        raise Exception('Unexpected reco: ' + prefix)
    return expected
def plot_report(title, expected, predicted, display_labels):
    """Show a confusion matrix and per-class classification metrics in Streamlit.

    :param title: heading rendered above the plots.
    :param expected: ground-truth labels.
    :param predicted: predicted labels, aligned with *expected*.
    :param display_labels: label order for the confusion-matrix axes.
    """
    st.markdown('#### ' + title)
    conf_matrix = confusion_matrix(expected, predicted, labels=display_labels)
    ConfusionMatrixDisplay(confusion_matrix=conf_matrix, display_labels=display_labels).plot()
    matrix_fig = plt.gcf()
    st.pyplot(matrix_fig)
    # pyplot keeps every figure alive until it is closed, so unclosed figures
    # accumulated across Streamlit reruns.
    plt.close(matrix_fig)

    report = classification_report(expected, predicted, output_dict=True)
    df_report = pd.DataFrame(report).transpose()
    st.write(df_report)
    # Keep only per-class rows. errors='ignore' because the summary rows are
    # not guaranteed to exist; previously a missing row raised KeyError
    # outside the try below.
    df_rp = df_report.drop(columns='support').drop(
        index=['accuracy', 'macro avg', 'weighted avg'], errors='ignore')

    bar_fig = None
    try:
        ax = df_rp.plot(kind="bar", legend=True)
        bar_fig = ax.figure
        for container in ax.containers:
            ax.bar_label(container, fontsize=7)
        plt.xticks(rotation=45)
        plt.legend(loc=(1.04, 0))
        st.pyplot(bar_fig)
    except Exception as e:
        # Best-effort chart (the original note mentions "out of bounds"
        # errors). Report the failure instead of silently discarding it.
        print('Unable to plot classification report for ' + str(title) + ': ' + str(e))
    finally:
        if bar_fig is not None:
            plt.close(bar_fig)
def get_complete_prompt(generic_prompt, guidelines_prompt):
    """Combine the generic and guideline prompts.

    Empty or None parts are skipped. When both are present they are joined
    by a period followed by a blank line.
    """
    present_parts = [part for part in (generic_prompt, guidelines_prompt) if part]
    return ".\n\n".join(present_parts)
def run_command(args):
    """Run an external command without a shell and return the completed process.

    Captured stdout/stderr are printed to the server log (non-zero exit codes
    additionally print stderr). They are also available on the returned
    ``subprocess.CompletedProcess``. Previously nothing was returned.

    :param args: sequence of the program followed by its arguments; items are
        converted with str().
    :return: the ``subprocess.CompletedProcess`` of the run.
    """
    # Pass a list, not ' '.join(args): with shell=False on POSIX a single
    # string is taken as the program name itself, so any command with
    # arguments failed with FileNotFoundError.
    cmd = [str(arg) for arg in args]
    result = subprocess.run(cmd, capture_output=True, text=True)
    print(result)
    if result.returncode != 0:
        print('Command failed with exit code ' + str(result.returncode) + ': ' + result.stderr)
    return result
def get_diarized_f_path(audio_f_name):
    """Return ``DATA_FOLDER + <audio name without its last extension> + ".txt"``.

    For example, ``"call.2023.wav"`` maps to ``DATA_FOLDER + "call.2023.txt"``.
    Any directory part of *audio_f_name* is kept below DATA_FOLDER.
    """
    import os  # local import: the module header does not import os

    # splitext removes only the final extension. The former split('.')[0]
    # cut at the first dot ("call.2023.wav" -> "call", "./a.wav" -> "").
    base_name = os.path.splitext(audio_f_name)[0]
    return DATA_FOLDER + base_name + ".txt"
def display_llm_output():
    """Render the free-form LLM form and, on submit, show the model's response.

    The message is sent to get_llm_response together with the model
    settings stored in ``st.session_state``.
    """
    st.header("LLM")
    llm_form = st.form('llm')
    message = llm_form.text_area('Message', value=st.session_state["llm_message"])
    if not llm_form.form_submit_button('Submit'):
        return
    state = st.session_state
    # NOTE(review): temperature/top_p chosen in the sidebar are not passed
    # here; confirm whether get_llm_response should receive them.
    response = get_llm_response(
        state["model_repo_id"],
        state["model_filename"],
        state["model_type"],
        state["gpu_layers"],
        message,
    )
    st.write(response)
    st.write('Done displaying LLM response')
def main():
    """App entry point.

    Initialises state until a session is started, always renders the sidebar,
    and shows the LLM form and session overview once a session is running.
    """
    print('Running Local LLM PoC Streamlit app...')
    session_inactive_info = st.empty()
    session_running = "session_started" in st.session_state and bool(st.session_state["session_started"])
    if not session_running:
        init_session_state()
    display_streamlit_sidebar()
    if session_running:
        session_inactive_info.empty()
        display_llm_output()
        display_session_overview()


if __name__ == '__main__':
    main()