performance optimization and bug fix

Files changed:
- app.py            (+4, -7)
- privateGPT.py     (+5, -5)
- streamlit_app.py  (+1, -5)
app.py

@@ -11,9 +11,6 @@ import transformers
 from torch import cuda, bfloat16
 
 
-
-load_dotenv()
-
 embeddings_model_name = "all-MiniLM-L6-v2"
 persist_directory = "db"
 model = "tiiuae/falcon-7b-instruct"
@@ -26,6 +23,10 @@ source_directory = os.environ.get('SOURCE_DIRECTORY', 'source_documents')
 
 from constants import CHROMA_SETTINGS
 
+
+embeddings = HuggingFaceEmbeddings(model_name=embeddings_model_name)
+llm = HuggingFacePipeline.from_model_id(model_id=model, task="text-generation", device=0, model_kwargs={"temperature":0.1,"trust_remote_code": True, "max_length":100000, "top_p":0.15, "top_k":0, "repetition_penalty":1.1, "num_return_sequences":1,})
+
 # async def test_embedding():
 #     # Create the folder if it doesn't exist
 #     os.makedirs(source_directory, exist_ok=True)
@@ -101,14 +102,10 @@ def embed_documents(files, collection_name: Optional[str] = None):
 def retrieve_documents(query: str, collection_name:str):
     target_source_chunks = 4
     mute_stream = ""
-    embeddings = HuggingFaceEmbeddings(model_name=embeddings_model_name)
     db = Chroma(persist_directory=persist_directory,collection_name=collection_name, embedding_function=embeddings, client_settings=CHROMA_SETTINGS)
     retriever = db.as_retriever(search_kwargs={"k": target_source_chunks})
     # Prepare the LLM
     callbacks = [] if mute_stream else [StreamingStdOutCallbackHandler()]
-
-    llm = HuggingFacePipeline.from_model_id(model_id=model, task="text-generation", device=0, model_kwargs={"temperature":0.1,"trust_remote_code": True, "max_length":100000, "top_p":0.15, "top_k":0, "repetition_penalty":1.1, "num_return_sequences":1,})
-
     qa = RetrievalQA.from_chain_type(llm=llm, chain_type="stuff", retriever=retriever, return_source_documents=False)
 
     # Get the answer from the chain
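The substance of this change: `embeddings` and the Falcon pipeline were previously constructed inside `retrieve_documents()`, so every query reloaded the embedding model and the 7B LLM. Hoisting both to module scope makes that a one-time cost paid at import. A minimal sketch of the pattern, with illustrative names rather than the repo's code:

import time

def _load_model():
    # Stand-in for the expensive part: downloading weights, moving them to GPU.
    time.sleep(2)
    return object()

# Before: the model is rebuilt on every call, so each query pays the full cost.
def retrieve_slow(query: str):
    model = _load_model()
    return model

# After: the model is built once at import and shared by every call.
_MODEL = _load_model()

def retrieve_fast(query: str):
    return _MODEL

The trade-off is that importing the module now pays the load cost (and claims GPU memory) even on code paths that never query; for a long-running API process that is usually the right choice.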
privateGPT.py

@@ -7,12 +7,12 @@ import os
 
 load_dotenv()
 
-embeddings_model_name = os.environ.get("EMBEDDINGS_MODEL_NAME")
-persist_directory = os.environ.get('PERSIST_DIRECTORY')
+# embeddings_model_name = os.environ.get("EMBEDDINGS_MODEL_NAME")
+# persist_directory = os.environ.get('PERSIST_DIRECTORY')
 
-model_type = os.environ.get('MODEL_TYPE')
-model_path = os.environ.get('MODEL_PATH')
-model_n_ctx = os.environ.get('MODEL_N_CTX')
+# model_type = os.environ.get('MODEL_TYPE')
+# model_path = os.environ.get('MODEL_PATH')
+# model_n_ctx = os.environ.get('MODEL_N_CTX')
 
 from constants import CHROMA_SETTINGS
 
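Commenting out these reads removes the dependency on a populated .env, but any remaining reference to names like `model_path` in this file would now raise a NameError. A hedged alternative sketch: keep the environment overrides and fall back to the hard-coded values app.py uses (the defaults below are assumptions mirroring app.py, not part of this commit):

import os

embeddings_model_name = os.environ.get("EMBEDDINGS_MODEL_NAME", "all-MiniLM-L6-v2")
persist_directory = os.environ.get("PERSIST_DIRECTORY", "db")
model_path = os.environ.get("MODEL_PATH", "tiiuae/falcon-7b-instruct")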
streamlit_app.py

@@ -7,7 +7,7 @@ import socket
 from urllib3.connection import HTTPConnection
 from app import embed_documents, retrieve_documents
 
-API_BASE_URL = os.environ.get("API_BASE_URL")
+# API_BASE_URL = os.environ.get("API_BASE_URL")
 
 
 embeddings_model_name = "all-MiniLM-L6-v2"
@@ -86,7 +86,3 @@ def get_collection_names():
 # else:
 #     st.error("Failed to retrieve documents.")
 #     st.write(response.text)
-
-
-if __name__ == "__main__":
-    main()
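Dropping the `if __name__ == "__main__":` block leans on Streamlit's execution model: `streamlit run` executes the whole script top to bottom on every interaction, so the guard is optional there, though something must still invoke `main()` (presumably a module-level call elsewhere in the file; this diff alone does not show one). A minimal sketch of the idiom, with a hypothetical `main()` standing in for the app's own:

import streamlit as st

def main():
    # Hypothetical page body; the repo's main() is defined earlier in the file.
    st.title("privateGPT")
    query = st.text_input("Ask a question")
    if query:
        st.write(f"Echo: {query}")

# Under `streamlit run`, the script itself is the entry point, so a plain
# module-level call is the common idiom; no __main__ guard is required.
main()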