|
import contextlib
import os
import subprocess
import sys
import urllib.parse
from typing import List, Optional

import transformers
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.chains import RetrievalQA
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.llms import HuggingFacePipeline
from langchain.vectorstores import Chroma
from torch import bfloat16, cuda
from transformers import AutoTokenizer, AutoModelForCausalLM
|
|
|
|
|
# --- Model / storage configuration -------------------------------------------

# Name passed to HuggingFaceEmbeddings for embedding text.
embeddings_model_name = "all-MiniLM-L6-v2"

# Directory handed to Chroma as its persist_directory.
persist_directory = "db"

# Hugging Face model id loaded by HuggingFacePipeline for text generation.
model = "tiiuae/falcon-7b-instruct"

# Staging folder where uploaded files are written before ingestion.
# Can be overridden with the SOURCE_DIRECTORY environment variable.
source_directory = os.getenv("SOURCE_DIRECTORY", "source_documents")
|
|
|
from constants import CHROMA_SETTINGS |
|
|
|
|
|
# Embedding model used as Chroma's embedding_function in retrieve_documents.
embeddings = HuggingFaceEmbeddings(model_name=embeddings_model_name)

# Text-generation LLM backing the RetrievalQA chain.
llm = HuggingFacePipeline.from_model_id(
    model_id=model,
    task="text-generation",
    # Use GPU 0 when CUDA is available, otherwise fall back to CPU (-1)
    # instead of failing at import time on machines without a GPU.
    device=0 if cuda.is_available() else -1,
    model_kwargs={
        # NOTE(review): no do_sample flag is set here; if generation is
        # greedy, temperature/top_p/top_k may have no effect — confirm.
        "temperature": 0.1,
        "trust_remote_code": True,
        # NOTE(review): 100000 looks far larger than the model's likely
        # context window — confirm the intended generation limit.
        "max_length": 100000,
        "top_p": 0.15,
        "top_k": 0,
        "repetition_penalty": 1.1,
        "num_return_sequences": 1,
    },
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def embed_documents(files, collection_name: Optional[str] = None):
    """Stage uploaded files on disk, run ``ingest.py`` on them, then clean up.

    Each file is written into ``source_directory``. Then ``ingest.py``
    (resolved relative to the current working directory) is run with
    ``--collection <collection_name>`` using the current Python interpreter.
    The staged copies are deleted afterwards, whether or not ingestion
    succeeded.

    Args:
        files: Iterable of file-like objects exposing ``.name`` and ``.read()``
            (presumably Streamlit ``UploadedFile`` objects — confirm against
            callers). ``None`` or an empty iterable is treated as "no files".
        collection_name: Collection name passed to ``ingest.py``. Defaults to
            the ``.name`` of the first file.

    Returns:
        dict with a ``"message"`` string and ``"saved_files"``, the list of
        paths the files were staged at (one entry per input file).

    Raises:
        ValueError: if a file name has no usable base name.
        subprocess.CalledProcessError: if ``ingest.py`` exits with a
            non-zero status.
    """
    files = list(files) if files else []
    saved_files: List[str] = []

    if not files:
        # Nothing to ingest. Without this guard, ingest.py would be run with
        # "--collection None".
        return {"message": "No files to embed", "saved_files": saved_files}

    os.makedirs(source_directory, exist_ok=True)
    try:
        for file in files:
            print(file)
            # Keep only the final path component so an untrusted upload name
            # such as "../../x" or "/etc/x" cannot escape source_directory.
            base_name = os.path.basename(file.name)
            if not base_name:
                raise ValueError(f"Invalid file name: {file.name!r}")
            file_path = os.path.join(source_directory, base_name)
            with open(file_path, "wb") as f:
                f.write(file.read())
            saved_files.append(file_path)

        if collection_name is None:
            collection_name = files[0].name

        # Pass an argument list (no shell), so a file or collection name
        # cannot inject shell commands. sys.executable runs ingest.py under
        # the same interpreter/environment as this app.
        subprocess.run(
            [sys.executable, "ingest.py", "--collection", collection_name],
            check=True,
        )
    finally:
        # Always remove the staged copies. Uploads that share a name map to
        # the same path, so a path may already be gone on its second visit.
        for file_path in saved_files:
            with contextlib.suppress(FileNotFoundError):
                os.remove(file_path)

    return {"message": "Files embedded successfully", "saved_files": saved_files}
|
|
|
def retrieve_documents(
    query: str,
    collection_name: str,
    *,
    target_source_chunks: int = 4,
) -> str:
    """Answer *query* with a RetrievalQA chain over a persisted Chroma collection.

    Opens the Chroma store at ``persist_directory`` for *collection_name* and
    retrieves ``target_source_chunks`` chunks per query. It builds a
    ``"stuff"`` RetrievalQA chain around the module-level ``llm`` and runs it
    on *query*.

    Args:
        query: Question passed to the chain.
        collection_name: Chroma collection to search (the name used when the
            documents were ingested).
        target_source_chunks: Number of chunks requested from the retriever
            (the ``k`` search kwarg). Defaults to 4.

    Returns:
        The chain's answer text (``res["result"]``).
    """
    db = Chroma(
        persist_directory=persist_directory,
        collection_name=collection_name,
        embedding_function=embeddings,
        client_settings=CHROMA_SETTINGS,
    )
    retriever = db.as_retriever(search_kwargs={"k": target_source_chunks})

    qa = RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=retriever,
        return_source_documents=False,
    )

    res = qa(query)
    print(res)

    # Display is left to the caller: this module never imports Streamlit
    # (``st``), so rendering with st.subheader/st.text here raised NameError.
    return res["result"]
|
|