import argparse
from itertools import islice
from pathlib import Path
from tqdm import tqdm
import torch
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import Qdrant
from loaders import get_loader, LOADER_NAMES
from config import DB_CONFIG
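
# DB_CONFIG is unpacked in store() below, so it is assumed to be a
# 3-tuple (url, api_key, collection_name), e.g. in config.py:
#   DB_CONFIG = ("https://qdrant.example.com:6333", "YOUR_API_KEY", "documents")
# (the URL, key, and collection name above are placeholders, not real values).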
CHUNK_SIZE = 500


def get_text_chunk(docs):
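    """Split documents into chunks of at most CHUNK_SIZE characters, with no overlap."""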
text_splitter = RecursiveCharacterTextSplitter(
chunk_size=CHUNK_SIZE, chunk_overlap=0
)
texts = text_splitter.split_documents(docs)
    return texts


def batched(iterable, *, size=100):
"Batch data into tuples of length n. The last batch may be shorter."
# batched('ABCDEFG', 3) --> ABC DEF G
if size < 1:
raise ValueError('n must be at least one')
it = iter(iterable)
while batch := tuple(islice(it, size)):
        yield batch


def store(texts):
model_name = "intfloat/multilingual-e5-large"
model_kwargs = {"device": "cuda:0" if torch.cuda.is_available() else "cpu"}
encode_kwargs = {"normalize_embeddings": False}
embeddings = HuggingFaceEmbeddings(
model_name=model_name,
model_kwargs=model_kwargs,
encode_kwargs=encode_kwargs,
)
db_url, db_api_key, db_collection_name = DB_CONFIG
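    # Embed and upsert one batch of chunks at a time, so a large corpus is
    # neither embedded in one pass nor sent to Qdrant in a single request.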
for batch in tqdm(batched(texts, size=100)):
_ = Qdrant.from_documents(
batch,
embeddings,
url=db_url,
api_key=db_api_key,
collection_name=db_collection_name,
        )


def get_parser():
p = argparse.ArgumentParser()
p.add_argument("index", type=str)
p.add_argument("inputfile", metavar="INPUTFILE", type=str)
p.add_argument("-l", "--loader", type=str, choices=LOADER_NAMES, required=True)
    return p


def index_annotated_docs(docs, index):
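    """Lazily tag each document's metadata with the target index name."""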
for doc in docs:
doc.metadata["index"] = index
        yield doc


def main():
"""
$ python store.py --loader wikipage "index" "FILE_PATH"
$ python store.py -l wikipage wiki data/wiki.json
$ python store.py -l rtdhtmlpage django ./docs.djangoproject.com/
"""
p = get_parser()
args = p.parse_args()
loader = get_loader(
args.loader,
inputfile=Path(args.inputfile),
)
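    # lazy_load() yields documents one at a time; each is tagged with the
    # index name by index_annotated_docs before being chunked.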
docs = loader.lazy_load()
texts = get_text_chunk(index_annotated_docs(docs, args.index))
    store(texts)


if __name__ == "__main__":
main()