import gradio as gr
import os

from langchain.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import Chroma
from langchain.embeddings import HuggingFaceEmbeddings 
from langchain.llms import HuggingFaceHub
from langchain.chains import ConversationalRetrievalChain
from langchain.memory import ConversationBufferMemory

from pathlib import Path
import chromadb

from transformers import AutoTokenizer
import transformers
import torch
import tqdm 
import accelerate


# Default LLM model
llm_model = "mistralai/Mistral-7B-Instruct-v0.2"

# Other settings
default_persist_directory = './chroma_HF/'
list_llm = [
    "mistralai/Mistral-7B-Instruct-v0.2", "mistralai/Mixtral-8x7B-Instruct-v0.1",
    "mistralai/Mistral-7B-Instruct-v0.1", "google/gemma-7b-it", "google/gemma-2b-it",
    "HuggingFaceH4/zephyr-7b-beta", "meta-llama/Llama-2-7b-chat-hf", "microsoft/phi-2",
    "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "mosaicml/mpt-7b-instruct",
    "tiiuae/falcon-7b-instruct", "google/flan-t5-xxl",
]
list_llm_simple = [os.path.basename(llm) for llm in list_llm]

# Load vector database
def load_db():
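    # No model_name is given, so langchain's HuggingFaceEmbeddings falls back
    # to its default sentence-transformers model (all-mpnet-base-v2).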
    embedding = HuggingFaceEmbeddings()
    vectordb = Chroma(
        persist_directory=default_persist_directory, 
        embedding_function=embedding)
    return vectordb


# Initialize langchain LLM chain
def initialize_llmchain(vector_db, progress=gr.Progress()):
    progress(0.5, desc="Initializing HF Hub...")
    # Extra options (e.g. trust_remote_code, load_in_8bit) have to be passed
    # through model_kwargs (known langchain issue).
    # URL: https://github.com/langchain-ai/langchain/issues/6080
    if llm_model == "mistralai/Mixtral-8x7B-Instruct-v0.1":
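        # Mixtral is the largest model in the list, so 8-bit loading is
        # requested as well to reduce memory use.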
        llm = HuggingFaceHub(
            repo_id=llm_model, 
            model_kwargs={"temperature": 0.7, "max_new_tokens": 1024, "top_k": 3, "load_in_8bit": True}
        )
    # ... (other model configurations for different model options)
    else:
        llm = HuggingFaceHub(
            repo_id=llm_model, 
            model_kwargs={"temperature": 0.7, "max_new_tokens": 1024, "top_k": 3}
        )
    
    progress(0.75, desc="Defining buffer memory...")
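    # output_key='answer' tells the memory which chain output to store, since
    # return_source_documents=True makes the chain return multiple keys.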
    memory = ConversationBufferMemory(
        memory_key="chat_history",
        output_key='answer',
        return_messages=True
    )
    # Default retriever over the Chroma index (plain similarity search).
    retriever = vector_db.as_retriever()
    progress(0.8, desc="Defining retrieval chain...")
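    # chain_type="stuff" concatenates all retrieved chunks into a single prompt.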
    qa_chain = ConversationalRetrievalChain.from_llm(
        llm,
        retriever=retriever,
        chain_type="stuff", 
        memory=memory,
        return_source_documents=True,
        verbose=False,
    )
    progress(0.9, desc="Done!")
    return qa_chain


# ... (other functions remain the same)
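

# Usage sketch (illustrative, not part of the original app): shows how the two
# helpers above could be wired together outside the Gradio UI. Assumes the
# HUGGINGFACEHUB_API_TOKEN environment variable is set, that a Chroma index was
# already persisted to default_persist_directory by the elided ingestion
# functions, and that the gr.Progress() calls (normally driven by a Gradio
# event handler) are tolerated when run standalone.
if __name__ == "__main__":
    vector_db = load_db()
    qa_chain = initialize_llmchain(vector_db)
    # With ConversationBufferMemory attached, chat_history is tracked
    # automatically between calls, so only the question is needed.
    response = qa_chain({"question": "What is this document about?"})
    print(response["answer"])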