leo-bourrel's picture
!feat: Import new sorbobot version
68a9b68
raw
history blame
No virus
6.42 kB
import html
import json
import os

import streamlit as st
import streamlit.components.v1 as components
from langchain.callbacks import get_openai_callback

from chain import get_chain
from chat_history import insert_chat_history, insert_chat_history_articles
from connection import connect
from css import load_css
from message import Message
# Page-level setup. The page configuration and title are issued ahead of every
# other Streamlit call in this script.
st.set_page_config(layout="wide")
st.title("Sorbobot - Le futur de la recherche scientifique interactive")
# Two columns weighted 2:1: the chat column first, then the source-documents
# column.
chat_column, doc_column = st.columns([2, 1])
# Handle returned by connect(); it is passed to get_chain() and to the
# chat-history insert helpers further down.
conn = connect()
def initialize_session_state():
    """Create the session-state entries the app relies on, if absent.

    Populates ``history`` (empty list), ``token_count`` (0) and
    ``conversation`` (the chain returned by ``get_chain(conn)``). Keys that
    already exist are left untouched, so the chain is only built when the
    session has none yet.
    """
    # Factories are called lazily so that get_chain(conn) only runs when the
    # "conversation" key is actually missing.
    default_factories = {
        "history": list,
        "token_count": lambda: 0,
        "conversation": lambda: get_chain(conn),
    }
    for key, make_default in default_factories.items():
        if key not in st.session_state:
            st.session_state[key] = make_default()
def send_message_callback():
    """Send the current prompt to the conversation chain and record the result.

    Reads ``st.session_state.human_prompt``, strips it and returns early when
    nothing is left. Otherwise the chain is called, the human prompt and the
    AI answer (with its source documents) are appended to
    ``st.session_state.history`` and the tokens reported by the OpenAI
    callback are added to ``st.session_state.token_count``. When the
    ``ENVIRONMENT`` variable equals ``"dev"``, the exchange and its source
    documents are also stored through the chat-history helpers.
    """
    with st.spinner("Wait for it..."), get_openai_callback() as token_usage:
        question = st.session_state.human_prompt.strip()
        if not question:
            return

        response = st.session_state.conversation(question)

        transcript = st.session_state.history
        transcript.append(Message("human", question))
        ai_reply = Message(
            "ai",
            response["answer"],
            documents=response["source_documents"],
        )
        transcript.append(ai_reply)

        st.session_state.token_count += token_usage.total_tokens

        # Persisting the exchange only happens in the "dev" environment.
        if os.environ.get("ENVIRONMENT") != "dev":
            return
        history_id = insert_chat_history(conn, question, response["answer"])
        insert_chat_history_articles(
            conn, history_id, response["source_documents"]
        )
def exemple_message_callback_button(args):
    """Submit an example prompt as though it had been typed in the chat box.

    Args:
        args: The example prompt text to send.
    """
    state = st.session_state
    # Stage the example in the prompt slot, send it, then empty the slot again.
    state.human_prompt = args
    send_message_callback()
    state.human_prompt = ""
def clear_history():
    """Start over: empty the transcript, zero the token counter and clear
    the conversation chain's memory."""
    state = st.session_state
    state.history.clear()
    state.token_count = 0
    state.conversation.memory.clear()
# Bootstrap before any UI is drawn: load_css() presumably injects the CSS
# classes used by the chat HTML below (confirm in css.py), and
# initialize_session_state() guarantees the session keys read by the UI exist.
load_css()
initialize_session_state()
# Example prompts offered as one-click buttons while the token counter is
# still zero.
exemples = [
"Who has published influential research on quantum computing?",
"List any prominent authors in the field of artificial intelligence ethics?",
"Who are the leading experts on climate change mitigation strategies?",
]
with chat_column:
    # Creation order fixes the on-page order: transcript, then the input
    # form, then a container for example prompts and debug output.
    chat_placeholder = st.container()
    prompt_placeholder = st.form("chat-form", clear_on_submit=True)
    information_placeholder = st.container()

    with chat_placeholder:
        # Static greeting bubble shown above the conversation.
        div = """
<div class="chat-row">
<img class="chat-icon" src="https://cdn-icons-png.flaticon.com/512/1129/1129398.png" width=32 height=32>
<div class="chat-bubble ai-bubble">
Welcome to SorboBot, a Hugging Face Space designed to revolutionize the way you find published articles. <br/>
Powered by a full export from ScanR and Hal at Sorbonne University, SorboBot utilizes advanced language model technology
to provide you with a list of published articles based on your prompt.
</div>
</div>
"""
        st.markdown(div, unsafe_allow_html=True)

        for chat in st.session_state.history:
            # The bubble is rendered as raw HTML, so the user's free-text
            # prompt must be escaped or it could inject markup/scripts.
            # NOTE(review): AI answers are still rendered unescaped in case
            # they carry intentional markup -- confirm whether they should be
            # sanitised as well.
            if chat.origin == "ai":
                message_html = chat.message
            else:
                message_html = html.escape(str(chat.message))

            div = f"""
<div class="chat-row
{'' if chat.origin == 'ai' else 'row-reverse'}">
<img class="chat-icon" src="https://cdn-icons-png.flaticon.com/512/{
'1129/1129398.png' if chat.origin == 'ai'
else '1077/1077012.png'}"
width=32 height=32>
<div class="chat-bubble
{'ai-bubble' if chat.origin == 'ai' else 'human-bubble'}">
&#8203;{message_html}
</div>
</div>
"""
            st.markdown(div, unsafe_allow_html=True)

        # Three empty markdown elements act as vertical spacing under the
        # transcript.
        for _ in range(3):
            st.markdown("")

    with prompt_placeholder:
        st.markdown("**Chat**")
        cols = st.columns((6, 1))
        cols[0].text_input(
            "Chat",
            label_visibility="collapsed",
            key="human_prompt",
        )
        cols[1].form_submit_button(
            "Submit",
            type="primary",
            on_click=send_message_callback,
        )

    # Example prompts are only offered while no tokens have been consumed.
    if st.session_state.token_count == 0:
        information_placeholder.markdown("### Test me !")
        for idx_exemple, exemple in enumerate(exemples):
            information_placeholder.button(
                exemple,
                key=f"{idx_exemple}_button",
                on_click=exemple_message_callback_button,
                args=(exemple,),
            )

    st.button(
        ":new: Start a new conversation", on_click=clear_history, type="secondary"
    )

    # Debug read-out, shown only when ENVIRONMENT is "dev".
    if os.environ.get("ENVIRONMENT") == "dev":
        information_placeholder.caption(
            f"""
Used {st.session_state.token_count} tokens \n
Debug Langchain conversation:
{st.session_state.history}
"""
        )

    # Zero-size HTML component whose script clicks the "Submit" button when
    # Enter is pressed anywhere in the parent document. The click is guarded
    # because the button lookup can come back empty, in which case the
    # unguarded call would throw on every Enter keypress.
    components.html(
        """
<script>
const streamlitDoc = window.parent.document;
const buttons = Array.from(
streamlitDoc.querySelectorAll('.stButton > button')
);
const submitButton = buttons.find(
el => el.innerText === 'Submit'
);
streamlitDoc.addEventListener('keydown', function(e) {
switch (e.key) {
case 'Enter':
if (submitButton) submitButton.click();
break;
}
});
</script>
""",
        height=0,
        width=0,
    )
with doc_column:
    st.markdown("**Source documents**")

    # Only the sources attached to the most recent message are listed.
    if st.session_state.history:
        latest_message = st.session_state.history[-1]
        for source in latest_message.documents:
            # page_content holds a JSON document; title, authors and keywords
            # are read from it, the remaining fields from the metadata.
            content = json.loads(source.page_content)
            metadata = source.metadata

            panel = st.expander(content["title"])
            panel.markdown(
                f"**HalID** : https://hal.science/{metadata['hal_id']}"
            )
            panel.markdown(metadata["abstract"])
            panel.markdown(f"**Authors** : {content['authors']}")
            panel.markdown(f"**Keywords** : {content['keywords']}")
            panel.markdown(f"**Distance** : {metadata['distance']}")