RishuD7 commited on
Commit
d27d5a2
1 parent: ba5c062

Performance optimization and bug fix: initialize the embeddings model and the HuggingFace LLM pipeline once at module load in app.py instead of on every `retrieve_documents` call; comment out unused environment-variable reads in privateGPT.py and streamlit_app.py; remove the unused `__main__` entry point from streamlit_app.py.

Browse files
Files changed (3) hide show
  1. app.py +4 -7
  2. privateGPT.py +5 -5
  3. streamlit_app.py +1 -5
app.py CHANGED
@@ -11,9 +11,6 @@ import transformers
11
  from torch import cuda, bfloat16
12
 
13
 
14
-
15
- load_dotenv()
16
-
17
  embeddings_model_name = "all-MiniLM-L6-v2"
18
  persist_directory = "db"
19
  model = "tiiuae/falcon-7b-instruct"
@@ -26,6 +23,10 @@ source_directory = os.environ.get('SOURCE_DIRECTORY', 'source_documents')
26
 
27
  from constants import CHROMA_SETTINGS
28
 
 
 
 
 
29
  # async def test_embedding():
30
  # # Create the folder if it doesn't exist
31
  # os.makedirs(source_directory, exist_ok=True)
@@ -101,14 +102,10 @@ def embed_documents(files, collection_name: Optional[str] = None):
101
  def retrieve_documents(query: str, collection_name:str):
102
  target_source_chunks = 4
103
  mute_stream = ""
104
- embeddings = HuggingFaceEmbeddings(model_name=embeddings_model_name)
105
  db = Chroma(persist_directory=persist_directory,collection_name=collection_name, embedding_function=embeddings, client_settings=CHROMA_SETTINGS)
106
  retriever = db.as_retriever(search_kwargs={"k": target_source_chunks})
107
  # Prepare the LLM
108
  callbacks = [] if mute_stream else [StreamingStdOutCallbackHandler()]
109
-
110
- llm = HuggingFacePipeline.from_model_id(model_id=model, task="text-generation", device=0, model_kwargs={"temperature":0.1,"trust_remote_code": True, "max_length":100000, "top_p":0.15, "top_k":0, "repetition_penalty":1.1, "num_return_sequences":1,})
111
-
112
  qa = RetrievalQA.from_chain_type(llm=llm, chain_type="stuff", retriever=retriever, return_source_documents=False)
113
 
114
  # Get the answer from the chain
 
11
  from torch import cuda, bfloat16
12
 
13
 
 
 
 
14
  embeddings_model_name = "all-MiniLM-L6-v2"
15
  persist_directory = "db"
16
  model = "tiiuae/falcon-7b-instruct"
 
23
 
24
  from constants import CHROMA_SETTINGS
25
 
26
+
27
+ embeddings = HuggingFaceEmbeddings(model_name=embeddings_model_name)
28
+ llm = HuggingFacePipeline.from_model_id(model_id=model, task="text-generation", device=0, model_kwargs={"temperature":0.1,"trust_remote_code": True, "max_length":100000, "top_p":0.15, "top_k":0, "repetition_penalty":1.1, "num_return_sequences":1,})
29
+
30
  # async def test_embedding():
31
  # # Create the folder if it doesn't exist
32
  # os.makedirs(source_directory, exist_ok=True)
 
102
  def retrieve_documents(query: str, collection_name:str):
103
  target_source_chunks = 4
104
  mute_stream = ""
 
105
  db = Chroma(persist_directory=persist_directory,collection_name=collection_name, embedding_function=embeddings, client_settings=CHROMA_SETTINGS)
106
  retriever = db.as_retriever(search_kwargs={"k": target_source_chunks})
107
  # Prepare the LLM
108
  callbacks = [] if mute_stream else [StreamingStdOutCallbackHandler()]
 
 
 
109
  qa = RetrievalQA.from_chain_type(llm=llm, chain_type="stuff", retriever=retriever, return_source_documents=False)
110
 
111
  # Get the answer from the chain
privateGPT.py CHANGED
@@ -7,12 +7,12 @@ import os
7
 
8
  load_dotenv()
9
 
10
- embeddings_model_name = os.environ.get("EMBEDDINGS_MODEL_NAME")
11
- persist_directory = os.environ.get('PERSIST_DIRECTORY')
12
 
13
- model_type = os.environ.get('MODEL_TYPE')
14
- model_path = os.environ.get('MODEL_PATH')
15
- model_n_ctx = os.environ.get('MODEL_N_CTX')
16
 
17
  from constants import CHROMA_SETTINGS
18
 
 
7
 
8
  load_dotenv()
9
 
10
+ # embeddings_model_name = os.environ.get("EMBEDDINGS_MODEL_NAME")
11
+ # persist_directory = os.environ.get('PERSIST_DIRECTORY')
12
 
13
+ # model_type = os.environ.get('MODEL_TYPE')
14
+ # model_path = os.environ.get('MODEL_PATH')
15
+ # model_n_ctx = os.environ.get('MODEL_N_CTX')
16
 
17
  from constants import CHROMA_SETTINGS
18
 
streamlit_app.py CHANGED
@@ -7,7 +7,7 @@ import socket
7
  from urllib3.connection import HTTPConnection
8
  from app import embed_documents, retrieve_documents
9
 
10
- API_BASE_URL = os.environ.get("API_BASE_URL")
11
 
12
 
13
  embeddings_model_name = "all-MiniLM-L6-v2"
@@ -86,7 +86,3 @@ def get_collection_names():
86
  # else:
87
  # st.error("Failed to retrieve documents.")
88
  # st.write(response.text)
89
-
90
-
91
- if __name__ == "__main__":
92
- main()
 
7
  from urllib3.connection import HTTPConnection
8
  from app import embed_documents, retrieve_documents
9
 
10
+ # API_BASE_URL = os.environ.get("API_BASE_URL")
11
 
12
 
13
  embeddings_model_name = "all-MiniLM-L6-v2"
 
86
  # else:
87
  # st.error("Failed to retrieve documents.")
88
  # st.write(response.text)