Spaces:
Paused
Paused
import streamlit as st | |
import pandas as pd | |
import numpy as np | |
import datetime | |
import pickle | |
import os | |
import csv | |
import torch | |
from tqdm.auto import tqdm | |
from langchain.text_splitter import RecursiveCharacterTextSplitter | |
# from langchain.vectorstores import Chroma | |
from langchain.vectorstores import FAISS | |
from langchain.embeddings import HuggingFaceInstructEmbeddings | |
from langchain import HuggingFacePipeline | |
from langchain.chains import RetrievalQA | |
# Page-level Streamlit settings: the page title and the page icon.
st.set_page_config(page_title='aitGPT', page_icon='✅')
def load_scraped_web_info():
    """Load the pickled AIT web documents and split them into chunks.

    Reads the pickle file ``ait-web-document`` from the current working
    directory and passes its items to a ``RecursiveCharacterTextSplitter``
    configured with ``chunk_size=500``, ``chunk_overlap=100`` and ``len`` as
    the length function.

    Returns:
        The chunked documents produced by ``create_documents``. Earlier
        versions computed this value and discarded it; returning it is
        backward-compatible with callers that ignore the result.
    """
    # NOTE(review): unpickling runs arbitrary code embedded in the file, so
    # "ait-web-document" must come from a trusted source.
    with open("ait-web-document", "rb") as fp:
        ait_web_documents = pickle.load(fp)

    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=500,
        chunk_overlap=100,
        length_function=len,
    )

    # tqdm only wraps the iteration to display progress while the list is
    # materialised. Presumably each item is a plain text string, as
    # create_documents expects -- TODO confirm against the scraper output.
    chunked_text = text_splitter.create_documents(list(tqdm(ait_web_documents)))
    return chunked_text
def load_embedding_model():
    """Create the ``hkunlp/instructor-base`` instruct-embedding model.

    The target device is CUDA when ``torch.cuda.is_available()`` reports a
    GPU and the CPU otherwise.

    Returns:
        The constructed ``HuggingFaceInstructEmbeddings`` instance.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    return HuggingFaceInstructEmbeddings(
        model_name='hkunlp/instructor-base',
        model_kwargs={'device': device},
    )
def load_faiss_index():
    """Load the FAISS vector store saved under ``faiss_index``.

    Relies on the module-level ``embedding_model`` global, which therefore
    has to be assigned before this function is called.

    Returns:
        The vector store returned by ``FAISS.load_local``.
    """
    return FAISS.load_local("faiss_index", embedding_model)
def load_llm_model():
    """Build the ``lmsys/fastchat-t5-3b-v1.0`` text2text-generation pipeline.

    Returns:
        The ``HuggingFacePipeline`` created by ``from_model_id``.
    """
    # An earlier configuration (no longer active) additionally passed
    # device_map="auto" and load_in_8bit=True, with repetition_penalty=1.5.
    model_options = {
        "max_length": 256,
        "temperature": 0,
        "torch_dtype": torch.float32,
        "repetition_penalty": 1.3,
    }
    return HuggingFacePipeline.from_model_id(
        model_id='lmsys/fastchat-t5-3b-v1.0',
        task='text2text-generation',
        model_kwargs=model_options,
    )
def load_retriever(llm, db):
    """Wire a language model and a vector store into a RetrievalQA chain.

    Args:
        llm: Language model handed to ``RetrievalQA.from_chain_type``.
        db: Vector store; its ``as_retriever()`` result is used as the
            chain's retriever.

    Returns:
        The chain built with ``chain_type="stuff"``.
    """
    retriever = db.as_retriever()
    return RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=retriever,
    )
# --------------
# Seed the per-session state the first time the script runs for a session:
# an empty question/answer history and a rating of 0 (0 is treated by the
# rest of the script as "no rating given").
if "history" not in st.session_state:
    st.session_state["history"] = []

if "session_rating" not in st.session_state:
    st.session_state["session_rating"] = 0
def update_score():
    """Copy the value stored under the ``rating`` key into ``session_rating``.

    Registered as the ``on_change`` callback of the rating radio widget.
    """
    st.session_state["session_rating"] = st.session_state["rating"]
# Assemble the retrieval pipeline. Order matters: load_faiss_index() reads
# the module-level `embedding_model`, and the QA chain needs both the LLM
# and the vector store.
load_scraped_web_info()

embedding_model = load_embedding_model()
vector_database = load_faiss_index()
llm_model = load_llm_model()
qa_retriever = load_retriever(llm=llm_model, db=vector_database)

print("all load done")

# NOTE(review): `query_input` is reassigned further down from a text_area,
# so the value read from this text_input is never used -- confirm whether
# this widget is still wanted.
query_input = st.text_input(label='your question')
def retrieve_document(query_input):
    """Run a similarity search for ``query_input`` on the global vector store.

    Args:
        query_input: The query passed to ``similarity_search``.

    Returns:
        Whatever ``vector_database.similarity_search`` returns for the query.
    """
    return vector_database.similarity_search(query_input)
def retrieve_answer(query_input):
    """Answer a question with the QA chain and render it with a rating widget.

    The question is extended with a fixed instruction to elaborate, run
    through the global ``qa_retriever``, shown in a text area, and followed
    by a 1-5 radio whose ``on_change`` callback is ``update_score``.

    Args:
        query_input: The user's question as a string.

    Returns:
        The answer returned by ``qa_retriever.run``.
    """
    prompt_answer = " ".join((query_input, "Try to elaborate as much as you can."))
    answer = qa_retriever.run(prompt_answer)

    st.text_area(label="Retrieved documents", value=answer)
    st.markdown('---')
    st.radio(
        label='please select the overall satifaction and helpfullness of the bot answer',
        options=[1, 2, 3, 4, 5],
        horizontal=True,
        on_change=update_score,
        key='rating',
    )
    return answer
# Static page header: title, project description and a latency warning.
INTRO_MARKDOWN = """
#### The aitGPT project is a virtual assistant developed by the :green[Asian Institute of Technology] that contains a vast amount of information gathered from 205 AIT-related websites.
The goal of this chatbot is to provide an alternative way for applicants and current students to access information about the institute, including admission procedures, campus facilities, and more.
"""
# NOTE(review): "thi app" in the warning below looks like a typo for
# "this app"; the string is left unchanged here.
LATENCY_WARNING = ' ⚠️ Please expect to wait **~ 10 - 20 seconds per question** as thi app is running on CPU against 3-billion-parameter LLM'

st.write("# aitGPT 🤖 ")
st.markdown(INTRO_MARKDOWN)
st.write(LATENCY_WARNING)
st.markdown("---")

# Question box and submit button; both values are read by the code below.
query_input = st.text_area(label='What would you like to know about AIT?')
generate_button = st.button(label='Submit!')
# When the submit button was clicked on this run, answer the question and
# remember the exchange in the session history.
if generate_button:
    answer = retrieve_answer(query_input)
    log = {
        "timestamp": datetime.datetime.now(),
        "question": query_input,
        "generated_answer": answer,
        "rating": st.session_state.session_rating,
    }
    st.session_state.history.append(log)

# A non-zero session_rating means update_score() has stored a rating that
# has not been persisted yet: append it, together with the most recent
# history entry, to the feedback file and clear the pending rating.
if st.session_state.session_rating != 0:
    last_entry = st.session_state.history[-1]
    # newline='' is required by the csv module so the writer controls line
    # endings itself (otherwise extra blank rows appear on some platforms).
    # UTF-8 matches the default encoding pandas uses to read the file back.
    with open('test_db', 'a', newline='', encoding='utf-8') as csvfile:
        csv.writer(csvfile).writerow([
            last_entry['timestamp'],
            last_entry['question'],
            last_entry['generated_answer'],
            st.session_state.session_rating,
        ])
    st.session_state.session_rating = 0
# Show the stored feedback, newest first.
# NOTE(review): assumes "test_db" already exists and starts with a header
# row that names a `timestamp` column -- the csv append earlier in this
# script never writes a header. TODO confirm how the file is seeded.
test_df = pd.read_csv("test_db", index_col=0).sort_values(
    by=['timestamp'],
    axis=0,
    ascending=False,
)
st.dataframe(test_df)