|
import os |
|
import glob |
|
from typing import List |
|
from dotenv import load_dotenv |
|
import argparse |
|
|
|
from langchain.document_loaders import ( |
|
CSVLoader, |
|
EverNoteLoader, |
|
PDFMinerLoader, |
|
TextLoader, |
|
UnstructuredEmailLoader, |
|
UnstructuredEPubLoader, |
|
UnstructuredHTMLLoader, |
|
UnstructuredMarkdownLoader, |
|
UnstructuredODTLoader, |
|
UnstructuredPowerPointLoader, |
|
UnstructuredWordDocumentLoader, |
|
) |
|
|
|
from langchain.text_splitter import RecursiveCharacterTextSplitter |
|
from langchain.vectorstores import Chroma |
|
from langchain.embeddings import HuggingFaceEmbeddings |
|
from langchain.docstore.document import Document |
|
from constants import CHROMA_SETTINGS |
|
|
|
|
|
# Read a local .env file (if any) into the process environment before
# anything else in the module runs.
load_dotenv()

# Module-level defaults. Note that main() defines its own locals with the
# same names and values rather than reading these.
model = "tiiuae/falcon-7b-instruct"
persist_directory = "db"
embeddings_model_name = "all-MiniLM-L6-v2"
|
|
|
# Maps a file extension (lower-case, with its leading dot) to a
# ``(loader_class, loader_kwargs)`` pair. A loader is built as
# ``loader_class(file_path, **loader_kwargs)``; the keys also drive the glob
# patterns used to discover files in the source directory.
LOADER_MAPPING = {
    ".csv": (CSVLoader, {}),
    ".docx": (UnstructuredWordDocumentLoader, {}),
    ".enex": (EverNoteLoader, {}),
    ".eml": (UnstructuredEmailLoader, {}),
    ".epub": (UnstructuredEPubLoader, {}),
    ".html": (UnstructuredHTMLLoader, {}),
    ".md": (UnstructuredMarkdownLoader, {}),
    ".odt": (UnstructuredODTLoader, {}),
    ".pdf": (PDFMinerLoader, {}),
    ".pptx": (UnstructuredPowerPointLoader, {}),
    # Text files are always decoded as UTF-8 rather than the platform default.
    ".txt": (TextLoader, {"encoding": "utf8"}),
}
|
|
|
|
|
def _load_file_documents(file_path: str) -> List[Document]:
    """Return every Document the matching loader produces for ``file_path``.

    The loader is looked up in ``LOADER_MAPPING`` by the file's extension,
    compared case-insensitively.

    Raises:
        ValueError: if the extension has no entry in ``LOADER_MAPPING``.
    """
    # Look only at the final path component so a dot in a directory name
    # (``dir.v2/README``) is never mistaken for an extension, and a name
    # without any dot yields an empty extension instead of a bogus one.
    _, dot, suffix = os.path.basename(file_path).rpartition(".")
    ext = f".{suffix.lower()}" if dot else ""
    try:
        loader_class, loader_args = LOADER_MAPPING[ext]
    except KeyError:
        raise ValueError(f"Unsupported file extension '{ext}'") from None
    return loader_class(file_path, **loader_args).load()


def load_single_document(file_path: str) -> Document:
    """Load ``file_path`` and return the first Document its loader yields.

    Loaders that produce several Documents per file have everything after
    the first one discarded here; ``load_documents`` keeps them all.

    Raises:
        ValueError: if the file extension is not in ``LOADER_MAPPING``.
        IndexError: if the loader returns no Documents at all.
    """
    return _load_file_documents(file_path)[0]


def load_documents(source_dir: str) -> List[Document]:
    """Recursively load every supported file found under ``source_dir``.

    Files are discovered per extension, in ``LOADER_MAPPING`` order, and all
    Documents produced by each file's loader are included in the result.
    """
    documents: List[Document] = []
    for ext in LOADER_MAPPING:
        pattern = os.path.join(source_dir, f"**/*{ext}")
        for file_path in glob.glob(pattern, recursive=True):
            # extend, not append: one file may expand into many Documents.
            documents.extend(_load_file_documents(file_path))
    return documents
|
|
|
|
|
def main(collection):
    """Ingest ``source_documents`` into a persisted Chroma collection.

    Loads every supported file from the ``source_documents`` directory
    (created if missing), splits the text into overlapping chunks, embeds
    the chunks and writes them to the ``db`` directory.

    Args:
        collection: Name of the Chroma collection to write to. ``None``
            (what argparse supplies when ``--collection`` is omitted) leaves
            the name to Chroma's own default instead of passing ``None``
            through as a collection name.
    """
    embeddings_model_name = "all-MiniLM-L6-v2"
    persist_directory = "db"
    source_directory = "source_documents"
    chunk_size = 500
    chunk_overlap = 50

    os.makedirs(source_directory, exist_ok=True)

    print(f"Loading documents from {source_directory}")
    documents = load_documents(source_directory)
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size, chunk_overlap=chunk_overlap
    )
    texts = text_splitter.split_documents(documents)
    print(f"Loaded {len(documents)} documents from {source_directory}")
    print(f"Split into {len(texts)} chunks of text (max. {chunk_size} characters each)")

    if not texts:
        # Nothing to embed: skip loading the embedding model and avoid
        # handing an empty batch to the vector store.
        print(f"No text found in {source_directory}; nothing to ingest")
        return

    embeddings = HuggingFaceEmbeddings(model_name=embeddings_model_name)

    store_kwargs = {
        "persist_directory": persist_directory,
        "client_settings": CHROMA_SETTINGS,
    }
    if collection is not None:
        store_kwargs["collection_name"] = collection

    db = Chroma.from_documents(texts, embeddings, **store_kwargs)
    db.persist()
    # Drop the reference to the store once it has been persisted.
    db = None
|
|
|
|
|
if __name__ == "__main__":
    # Command-line entry point: the only option is the target collection
    # name, which is None when the flag is not given.
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--collection",
        help="Saves the embedding in a collection name as specified",
    )
    main(cli.parse_args().collection)
|
|