Ubai's picture
Update app.py
339ce69 verified
raw
history blame
2.7 kB
import gradio as gr
import os
from langchain.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import Chroma
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.llms import HuggingFaceHub
from langchain.chains import ConversationalRetrievalChain
from langchain.memory import ConversationBufferMemory
from pathlib import Path
import chromadb
from transformers import AutoTokenizer
import transformers
import torch
import tqdm
import accelerate
# Location on disk of the persisted Chroma vector store opened by load_db().
default_persist_directory = './chroma_HF/'

# Hugging Face Hub repository IDs ("owner/model") offered as LLM choices.
list_llm = [
    "mistralai/Mistral-7B-Instruct-v0.2",
    "mistralai/Mixtral-8x7B-Instruct-v0.1",
    "mistralai/Mistral-7B-Instruct-v0.1",
    "google/gemma-7b-it",
    "google/gemma-2b-it",
    "HuggingFaceH4/zephyr-7b-beta",
    "meta-llama/Llama-2-7b-chat-hf",
    "microsoft/phi-2",
    "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
    "mosaicml/mpt-7b-instruct",
    "tiiuae/falcon-7b-instruct",
    "google/flan-t5-xxl",
]

# Default LLM model: the first entry of list_llm.
llm_model = list_llm[0]

# Short display names: everything after the last "/" of each repository ID.
list_llm_simple = [repo_id.rpartition("/")[2] for repo_id in list_llm]
def load_db():
    """Open the Chroma vector store persisted at ``default_persist_directory``.

    A default-configured ``HuggingFaceEmbeddings`` instance is attached as the
    store's embedding function.

    Returns:
        The ``Chroma`` vector store object.
    """
    # NOTE(review): nothing here checks that the directory already holds a
    # populated collection — confirm Chroma's behavior when it is missing.
    embedder = HuggingFaceEmbeddings()
    return Chroma(
        embedding_function=embedder,
        persist_directory=default_persist_directory,
    )
def initialize_llmchain(
    vector_db,
    progress=gr.Progress(),
    *,
    llm_option=None,
    temperature=0.7,
    max_tokens=1024,
    top_k=3,
):
    """Build a conversational retrieval QA chain over ``vector_db``.

    Creates a ``HuggingFaceHub`` LLM and a ``ConversationBufferMemory`` for the
    chat history. These are wired to the store's retriever in a
    ``ConversationalRetrievalChain`` (``chain_type="stuff"``) that also returns
    the source documents it used.

    Args:
        vector_db: Vector store exposing ``as_retriever()`` (presumably the
            store returned by ``load_db()`` — confirm against callers).
        progress: Gradio progress tracker (default ``gr.Progress()``), updated
            as each stage is built.
        llm_option: Hub repository ID of the model to use. ``None`` (the
            default) uses the module-level ``llm_model``, read at call time
            exactly as before.
        temperature: Sampling temperature sent in ``model_kwargs``.
        max_tokens: Value sent as ``max_new_tokens`` in ``model_kwargs``.
        top_k: Value sent as ``top_k`` in ``model_kwargs``.

    Returns:
        The configured ``ConversationalRetrievalChain``.
    """
    # Resolve the model at call time so reassignments of the module-level
    # llm_model are still honoured when no explicit option is given.
    model_id = llm_model if llm_option is None else llm_option

    progress(0.5, desc="Initializing HF Hub...")
    model_kwargs = {
        "temperature": temperature,
        "max_new_tokens": max_tokens,
        "top_k": top_k,
    }
    if model_id == "mistralai/Mixtral-8x7B-Instruct-v0.1":
        # Mixtral is the only model configured with 8-bit loading.
        # NOTE(review): load_in_8bit is a model-loading option; confirm the
        # hosted Hub endpoint honours it rather than ignoring/rejecting it.
        model_kwargs["load_in_8bit"] = True
    # NOTE(review): a previous comment said further per-model configurations
    # belonged here, and referenced langchain issue #6080
    # (https://github.com/langchain-ai/langchain/issues/6080) about passing
    # trust_remote_code via model_kwargs; no such flag is passed here.
    llm = HuggingFaceHub(repo_id=model_id, model_kwargs=model_kwargs)

    progress(0.75, desc="Defining buffer memory...")
    # The chain below also returns source documents, so the memory is told
    # explicitly which output ('answer') to record in the chat history.
    memory = ConversationBufferMemory(
        memory_key="chat_history",
        output_key='answer',
        return_messages=True,
    )
    retriever = vector_db.as_retriever()

    progress(0.8, desc="Defining retrieval chain...")
    qa_chain = ConversationalRetrievalChain.from_llm(
        llm,
        retriever=retriever,
        chain_type="stuff",
        memory=memory,
        return_source_documents=True,
        verbose=False,
    )
    progress(0.9, desc="Done!")
    return qa_chain
# ... (other functions remain the same)