angry-meow committed
Commit 31f9732 · 1 Parent(s): 7f24de7

Few new models

.gitignore CHANGED
@@ -1 +1,2 @@
  .env
+ /__pycache__
__pycache__/constants.cpython-311.pyc CHANGED
Binary files a/__pycache__/constants.cpython-311.pyc and b/__pycache__/constants.cpython-311.pyc differ
 
__pycache__/models.cpython-311.pyc CHANGED
Binary files a/__pycache__/models.cpython-311.pyc and b/__pycache__/models.cpython-311.pyc differ
 
load_existing_docs.py CHANGED
@@ -6,12 +6,7 @@ from langchain_community.document_loaders import PyPDFLoader, UnstructuredURLLoa
  from qdrant_client.http.models import VectorParams
  import pymupdf
  import requests
-
- #qdrant = QdrantVectorStore.from_existing_collection(
- #    embedding=models.basic_embeddings,
- #    collection_name="kai_test_documents",
- #    url=constants.QDRANT_ENDPOINT,
- #)
+ from transformers import AutoTokenizer
 
  def extract_links_from_pdf(pdf_path):
      links = []
@@ -78,26 +73,22 @@ for link in unique_links:
 
 
  #print(len(documents))
- semantic_split_docs = models.semanticChunker.split_documents(documents)
- RCTS_split_docs = models.RCTS.split_documents(documents)
-
-
- #for file in filepaths:
- #    loader = PyPDFLoader(file)
- #    documents = loader.load()
- #    for doc in documents:
- #        doc.metadata = {
- #            "source": file,
- #            "tag": "employee" if "employee" in file.lower() else "employer"
- #        }
- #    all_documents.extend(documents)
-
- #chunk them
- #semantic_split_docs = models.semanticChunker.split_documents(all_documents)
-
-
+ #semantic_split_docs = models.semanticChunker.split_documents(documents)
+ semantic_tuned_split_docs = models.semanticChunker_tuned.split_documents(documents)
+ #RCTS_split_docs = models.RCTS.split_documents(documents)
+ #print(len(semantic_split_docs))
+ print(len(semantic_tuned_split_docs))
+ #tokenizer = models.tuned_embeddings.client.tokenizer
+ #
+ #token_sizes = [len(tokenizer.encode(chunk)) for chunk in semantic_tuned_split_docs]
+
+ # Display the token sizes
+ #for idx, size in enumerate(token_sizes):
+ #    print(f"Chunk {idx + 1}: {size} tokens")
+ #
+ #exit()
  #add them to the existing qdrant client
- collection_name = "docs_from_ripped_urls_recursive"
+ collection_name = "docs_from_ripped_urls_semantic_tuned"
 
  collections = models.qdrant_client.get_collections()
  collection_names = [collection.name for collection in collections.collections]
@@ -105,16 +96,16 @@ collection_names = [collection.name for collection in collections.collections]
  if collection_name not in collection_names:
      models.qdrant_client.create_collection(
          collection_name=collection_name,
-         vectors_config=VectorParams(size=1536, distance="Cosine")
+         vectors_config=VectorParams(size=1024, distance="Cosine")
      )
 
  qdrant_vector_store = QdrantVectorStore(
      client=models.qdrant_client,
      collection_name=collection_name,
-     embedding=models.te3_small
+     embedding=models.tuned_embeddings
  )
 
- qdrant_vector_store.add_documents(RCTS_split_docs)
+ qdrant_vector_store.add_documents(semantic_tuned_split_docs)
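Note on the change above: the collection is now created with `VectorParams(size=1024, ...)` instead of 1536 because the documents are embedded with `models.tuned_embeddings` (a snowflake-arctic-embed-l fine-tune, 1024-dimensional) rather than OpenAI's `text-embedding-3-small` (1536-dimensional). A minimal sanity-check sketch, not part of the commit, assuming `models.py` imports cleanly and the tuned model downloads:

```python
# Hypothetical check (not in the commit): confirm the embedding dimension
# matches the size the Qdrant collection was created with.
import models

vec = models.tuned_embeddings.embed_query("dimension check")
print(f"tuned_embeddings dimension: {len(vec)}")

# The diff uses VectorParams(size=1024, distance="Cosine"); a mismatch here
# would cause add_documents() to fail against the collection.
assert len(vec) == 1024
```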
models.py CHANGED
@@ -5,9 +5,11 @@ from langchain.callbacks.tracers import LangChainTracer
  from langchain_huggingface.embeddings import HuggingFaceEmbeddings
  from langchain_experimental.text_splitter import SemanticChunker
  from langchain_openai.embeddings import OpenAIEmbeddings
- from langchain_community.vectorstores import Qdrant
+ from langchain_qdrant import QdrantVectorStore, Qdrant
+ from langchain.retrievers.contextual_compression import ContextualCompressionRetriever
  from qdrant_client import QdrantClient
  from langchain_text_splitters import RecursiveCharacterTextSplitter
+ from langchain_cohere import CohereRerank
  import constants
  import os
 
@@ -66,10 +68,8 @@ gpt4o_mini = ChatOpenAI(
  )
 
  basic_embeddings = HuggingFaceEmbeddings(model_name="snowflake/snowflake-arctic-embed-l")
- #hkunlp_instructor_large = HuggingFaceInstructEmbeddings(
- #    model_name = "hkunlp/instructor-large",
- #    query_instruction="Represent the query for retrieval: "
- #)
+
+ tuned_embeddings = HuggingFaceEmbeddings(model_name="CoExperiences/snowflake-l-marketing-tuned")
 
  te3_small = OpenAIEmbeddings(api_key=constants.OPENAI_API_KEY, model="text-embedding-3-small")
 
@@ -78,9 +78,28 @@ semanticChunker = SemanticChunker(
      breakpoint_threshold_type="percentile"
  )
 
+ semanticChunker_tuned = SemanticChunker(
+     tuned_embeddings,
+     breakpoint_threshold_type="percentile",
+     breakpoint_threshold_amount=85
+ )
+
  RCTS = RecursiveCharacterTextSplitter(
      # Set a really small chunk size, just to show.
      chunk_size=500,
      chunk_overlap=25,
      length_function=len,
+ )
+
+ semantic_tuned_Qdrant_vs = QdrantVectorStore(
+     client=qdrant_client,
+     collection_name="docs_from_ripped_urls_semantic_tuned",
+     embedding=tuned_embeddings
+ )
+ semantic_tuned_retriever = semantic_tuned_Qdrant_vs.as_retriever(search_kwargs={"k" : 10})
+
+ #compression
+ compressor = CohereRerank(model="rerank-english-v3.0")
+ compression_retriever = ContextualCompressionRetriever(
+     base_compressor=compressor, base_retriever=semantic_tuned_retriever
  )
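With this change, models.py exposes `semantic_tuned_retriever` (top-10 vector search over the tuned-embedding collection) wrapped in `compression_retriever`, which reranks the hits with Cohere's rerank-english-v3.0. A minimal usage sketch, not part of the commit, assuming a `COHERE_API_KEY` is available in the environment and the collection has been populated by load_existing_docs.py; the query string is only an example:

```python
import models

# Pull the 10 nearest chunks from the tuned-embedding Qdrant collection,
# then let the Cohere reranker reorder and trim them.
docs = models.compression_retriever.invoke("example question about the ingested documents")

for doc in docs:
    # Print the source plus a short preview of each reranked chunk.
    print(doc.metadata.get("source"), "-", doc.page_content[:80])
```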
tuning/requirements.in ADDED
@@ -0,0 +1,14 @@
+ langchain_openai
+ langchain_huggingface
+ langchain_core==0.2.38
+ langchain
+ langchain_community
+ langchain-text-splitters
+ faiss-cpu
+ unstructured==0.15.7
+ python-pptx==1.0.2
+ nltk==3.9.1
+ pyarrow
+ sentence_transformers
+ datasets
+ ragas
tuning/requirements.txt ADDED
@@ -0,0 +1,412 @@
+ #
+ # This file is autogenerated by pip-compile with Python 3.11
+ # by the following command:
+ #
+ #    pip-compile requirements.in
+ #
+ aiohappyeyeballs==2.4.3
+     # via aiohttp
+ aiohttp==3.10.10
+     # via
+     #   datasets
+     #   fsspec
+     #   langchain
+     #   langchain-community
+ aiosignal==1.3.1
+     # via aiohttp
+ annotated-types==0.7.0
+     # via pydantic
+ anyio==4.6.2.post1
+     # via
+     #   httpx
+     #   openai
+ appdirs==1.4.4
+     # via ragas
+ attrs==24.2.0
+     # via aiohttp
+ backoff==2.2.1
+     # via unstructured
+ beautifulsoup4==4.12.3
+     # via unstructured
+ certifi==2024.8.30
+     # via
+     #   httpcore
+     #   httpx
+     #   requests
+ cffi==1.17.1
+     # via cryptography
+ chardet==5.2.0
+     # via unstructured
+ charset-normalizer==3.4.0
+     # via requests
+ click==8.1.7
+     # via nltk
+ cryptography==43.0.1
+     # via unstructured-client
+ dataclasses-json==0.6.7
+     # via
+     #   langchain-community
+     #   unstructured
+ datasets==3.0.1
+     # via
+     #   -r requirements.in
+     #   ragas
+ dill==0.3.8
+     # via
+     #   datasets
+     #   multiprocess
+ distro==1.9.0
+     # via openai
+ emoji==2.14.0
+     # via unstructured
+ eval-type-backport==0.2.0
+     # via unstructured-client
+ faiss-cpu==1.9.0
+     # via -r requirements.in
+ filelock==3.16.1
+     # via
+     #   datasets
+     #   huggingface-hub
+     #   torch
+     #   transformers
+     #   triton
+ filetype==1.2.0
+     # via unstructured
+ frozenlist==1.4.1
+     # via
+     #   aiohttp
+     #   aiosignal
+ fsspec[http]==2024.6.1
+     # via
+     #   datasets
+     #   huggingface-hub
+     #   torch
+ greenlet==3.1.1
+     # via sqlalchemy
+ h11==0.14.0
+     # via httpcore
+ httpcore==1.0.6
+     # via httpx
+ httpx==0.27.2
+     # via
+     #   langsmith
+     #   openai
+     #   unstructured-client
+ huggingface-hub==0.26.0
+     # via
+     #   datasets
+     #   langchain-huggingface
+     #   sentence-transformers
+     #   tokenizers
+     #   transformers
+ idna==3.10
+     # via
+     #   anyio
+     #   httpx
+     #   requests
+     #   yarl
+ jinja2==3.1.4
+     # via torch
+ jiter==0.6.1
+     # via openai
+ joblib==1.4.2
+     # via
+     #   nltk
+     #   scikit-learn
+ jsonpatch==1.33
+     # via langchain-core
+ jsonpath-python==1.0.6
+     # via unstructured-client
+ jsonpointer==3.0.0
+     # via jsonpatch
+ langchain==0.2.16
+     # via
+     #   -r requirements.in
+     #   langchain-community
+     #   ragas
+ langchain-community==0.2.16
+     # via
+     #   -r requirements.in
+     #   ragas
+ langchain-core==0.2.38
+     # via
+     #   -r requirements.in
+     #   langchain
+     #   langchain-community
+     #   langchain-huggingface
+     #   langchain-openai
+     #   langchain-text-splitters
+     #   ragas
+ langchain-huggingface==0.0.3
+     # via -r requirements.in
+ langchain-openai==0.1.23
+     # via
+     #   -r requirements.in
+     #   ragas
+ langchain-text-splitters==0.2.4
+     # via
+     #   -r requirements.in
+     #   langchain
+ langdetect==1.0.9
+     # via unstructured
+ langsmith==0.1.136
+     # via
+     #   langchain
+     #   langchain-community
+     #   langchain-core
+ lxml==5.3.0
+     # via
+     #   python-pptx
+     #   unstructured
+ markupsafe==3.0.2
+     # via jinja2
+ marshmallow==3.23.0
+     # via dataclasses-json
+ mpmath==1.3.0
+     # via sympy
+ multidict==6.1.0
+     # via
+     #   aiohttp
+     #   yarl
+ multiprocess==0.70.16
+     # via datasets
+ mypy-extensions==1.0.0
+     # via typing-inspect
+ nest-asyncio==1.6.0
+     # via
+     #   ragas
+     #   unstructured-client
+ networkx==3.4.1
+     # via torch
+ nltk==3.9.1
+     # via
+     #   -r requirements.in
+     #   unstructured
+ numpy==1.26.4
+     # via
+     #   datasets
+     #   faiss-cpu
+     #   langchain
+     #   langchain-community
+     #   pandas
+     #   pyarrow
+     #   ragas
+     #   scikit-learn
+     #   scipy
+     #   transformers
+     #   unstructured
+ nvidia-cublas-cu12==12.4.5.8
+     # via
+     #   nvidia-cudnn-cu12
+     #   nvidia-cusolver-cu12
+     #   torch
+ nvidia-cuda-cupti-cu12==12.4.127
+     # via torch
+ nvidia-cuda-nvrtc-cu12==12.4.127
+     # via torch
+ nvidia-cuda-runtime-cu12==12.4.127
+     # via torch
+ nvidia-cudnn-cu12==9.1.0.70
+     # via torch
+ nvidia-cufft-cu12==11.2.1.3
+     # via torch
+ nvidia-curand-cu12==10.3.5.147
+     # via torch
+ nvidia-cusolver-cu12==11.6.1.9
+     # via torch
+ nvidia-cusparse-cu12==12.3.1.170
+     # via
+     #   nvidia-cusolver-cu12
+     #   torch
+ nvidia-nccl-cu12==2.21.5
+     # via torch
+ nvidia-nvjitlink-cu12==12.4.127
+     # via
+     #   nvidia-cusolver-cu12
+     #   nvidia-cusparse-cu12
+     #   torch
+ nvidia-nvtx-cu12==12.4.127
+     # via torch
+ openai==1.52.0
+     # via
+     #   langchain-openai
+     #   ragas
+ orjson==3.10.7
+     # via langsmith
+ packaging==24.1
+     # via
+     #   datasets
+     #   faiss-cpu
+     #   huggingface-hub
+     #   langchain-core
+     #   marshmallow
+     #   transformers
+ pandas==2.2.3
+     # via datasets
+ pillow==11.0.0
+     # via
+     #   python-pptx
+     #   sentence-transformers
+ propcache==0.2.0
+     # via yarl
+ psutil==6.1.0
+     # via unstructured
+ pyarrow==17.0.0
+     # via
+     #   -r requirements.in
+     #   datasets
+ pycparser==2.22
+     # via cffi
+ pydantic==2.9.2
+     # via
+     #   langchain
+     #   langchain-core
+     #   langsmith
+     #   openai
+     #   ragas
+     #   unstructured-client
+ pydantic-core==2.23.4
+     # via pydantic
+ pypdf==5.0.1
+     # via unstructured-client
+ pysbd==0.3.4
+     # via ragas
+ python-dateutil==2.8.2
+     # via
+     #   pandas
+     #   unstructured-client
+ python-iso639==2024.4.27
+     # via unstructured
+ python-magic==0.4.27
+     # via unstructured
+ python-pptx==1.0.2
+     # via -r requirements.in
+ pytz==2024.2
+     # via pandas
+ pyyaml==6.0.2
+     # via
+     #   datasets
+     #   huggingface-hub
+     #   langchain
+     #   langchain-community
+     #   langchain-core
+     #   transformers
+ ragas==0.2.1
+     # via -r requirements.in
+ rapidfuzz==3.10.0
+     # via unstructured
+ regex==2024.9.11
+     # via
+     #   nltk
+     #   tiktoken
+     #   transformers
+ requests==2.32.3
+     # via
+     #   datasets
+     #   huggingface-hub
+     #   langchain
+     #   langchain-community
+     #   langsmith
+     #   requests-toolbelt
+     #   tiktoken
+     #   transformers
+     #   unstructured
+ requests-toolbelt==1.0.0
+     # via
+     #   langsmith
+     #   unstructured-client
+ safetensors==0.4.5
+     # via transformers
+ scikit-learn==1.5.2
+     # via sentence-transformers
+ scipy==1.14.1
+     # via
+     #   scikit-learn
+     #   sentence-transformers
+ sentence-transformers==3.2.0
+     # via
+     #   -r requirements.in
+     #   langchain-huggingface
+ six==1.16.0
+     # via
+     #   langdetect
+     #   python-dateutil
+ sniffio==1.3.1
+     # via
+     #   anyio
+     #   httpx
+     #   openai
+ soupsieve==2.6
+     # via beautifulsoup4
+ sqlalchemy==2.0.36
+     # via
+     #   langchain
+     #   langchain-community
+ sympy==1.13.1
+     # via torch
+ tabulate==0.9.0
+     # via unstructured
+ tenacity==8.5.0
+     # via
+     #   langchain
+     #   langchain-community
+     #   langchain-core
+ threadpoolctl==3.5.0
+     # via scikit-learn
+ tiktoken==0.8.0
+     # via
+     #   langchain-openai
+     #   ragas
+ tokenizers==0.20.1
+     # via
+     #   langchain-huggingface
+     #   transformers
+ torch==2.5.0
+     # via sentence-transformers
+ tqdm==4.66.5
+     # via
+     #   datasets
+     #   huggingface-hub
+     #   nltk
+     #   openai
+     #   sentence-transformers
+     #   transformers
+     #   unstructured
+ transformers==4.45.2
+     # via
+     #   langchain-huggingface
+     #   sentence-transformers
+ triton==3.1.0
+     # via torch
+ typing-extensions==4.12.2
+     # via
+     #   huggingface-hub
+     #   langchain-core
+     #   openai
+     #   pydantic
+     #   pydantic-core
+     #   python-pptx
+     #   sqlalchemy
+     #   torch
+     #   typing-inspect
+     #   unstructured
+ typing-inspect==0.9.0
+     # via
+     #   dataclasses-json
+     #   unstructured-client
+ tzdata==2024.2
+     # via pandas
+ unstructured==0.15.7
+     # via -r requirements.in
+ unstructured-client==0.26.1
+     # via unstructured
+ urllib3==2.2.3
+     # via requests
+ wrapt==1.16.0
+     # via unstructured
+ xlsxwriter==3.2.0
+     # via python-pptx
+ xxhash==3.5.0
+     # via datasets
+ yarl==1.15.4
+     # via aiohttp
tuning/tuning_embeddings_sandbox.ipynb ADDED
The diff for this file is too large to render. See raw diff