XThomasBU commited on
Commit
a2ac5f7
1 Parent(s): e5cd1d3

more modularization for vectorestore and retriever

Browse files
code/modules/chat/llm_tutor.py CHANGED
@@ -10,7 +10,7 @@ from modules.chat.helpers import get_prompt
10
  from modules.chat.chat_model_loader import ChatModelLoader
11
  from modules.vectorstore.store_manager import VectorStoreManager
12
 
13
- from modules.retriever import FaissRetriever, ChromaRetriever
14
 
15
  from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
16
  from langchain_core.callbacks.manager import AsyncCallbackManagerForChainRun
@@ -159,16 +159,7 @@ class LLMTutor:
159
  # Retrieval QA Chain
160
  def retrieval_qa_chain(self, llm, prompt, db):
161
 
162
- if self.config["vectorstore"]["db_option"] == "FAISS":
163
- retriever = FaissRetriever().return_retriever(db, self.config)
164
-
165
- elif self.config["vectorstore"]["db_option"] == "Chroma":
166
- retriever = ChromaRetriever().return_retriever(db, self.config)
167
-
168
- elif self.config["vectorstore"]["db_option"] == "RAGatouille":
169
- retriever = db.as_langchain_retriever(
170
- k=self.config["vectorstore"]["search_top_k"]
171
- )
172
 
173
  if self.config["llm_params"]["use_history"]:
174
  memory = ConversationBufferWindowMemory(
 
10
  from modules.chat.chat_model_loader import ChatModelLoader
11
  from modules.vectorstore.store_manager import VectorStoreManager
12
 
13
+ from modules.retriever import Retriever
14
 
15
  from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
16
  from langchain_core.callbacks.manager import AsyncCallbackManagerForChainRun
 
159
  # Retrieval QA Chain
160
  def retrieval_qa_chain(self, llm, prompt, db):
161
 
162
+ retriever = Retriever(self.config)._return_retriever(db)
 
 
 
 
 
 
 
 
 
163
 
164
  if self.config["llm_params"]["use_history"]:
165
  memory = ConversationBufferWindowMemory(
code/modules/config/config.yml CHANGED
@@ -7,7 +7,7 @@ vectorstore:
7
  data_path: '../storage/data' # str
8
  url_file_path: '../storage/data/urls.txt' # str
9
  expand_urls: False # bool
10
- db_option : 'RAGatouille' # str [FAISS, Chroma, RAGatouille]
11
  db_path : '../vectorstores' # str
12
  model : 'sentence-transformers/all-MiniLM-L6-v2' # str [sentence-transformers/all-MiniLM-L6-v2, text-embedding-ada-002']
13
  search_top_k : 3 # int
 
7
  data_path: '../storage/data' # str
8
  url_file_path: '../storage/data/urls.txt' # str
9
  expand_urls: False # bool
10
+ db_option : 'Chroma' # str [FAISS, Chroma, RAGatouille]
11
  db_path : '../vectorstores' # str
12
  model : 'sentence-transformers/all-MiniLM-L6-v2' # str [sentence-transformers/all-MiniLM-L6-v2, text-embedding-ada-002']
13
  search_top_k : 3 # int
code/modules/retriever/__init__.py CHANGED
@@ -1,2 +1,4 @@
1
  from .faiss_retriever import FaissRetriever
2
  from .chroma_retriever import ChromaRetriever
 
 
 
1
  from .faiss_retriever import FaissRetriever
2
  from .chroma_retriever import ChromaRetriever
3
+ from .colbert_retriever import ColbertRetriever
4
+ from .retriever import Retriever
code/modules/retriever/colbert_retriever.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ from .base import BaseRetriever
2
+
3
+
4
+ class ColbertRetriever(BaseRetriever):
5
+ def __init__(self):
6
+ pass
7
+
8
+ def return_retriever(self, db, config):
9
+ retriever = db.as_retriever()
10
+ return retriever
code/modules/retriever/retriever.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from modules.retriever.faiss_retriever import FaissRetriever
2
+ from modules.retriever.chroma_retriever import ChromaRetriever
3
+ from modules.retriever.colbert_retriever import ColbertRetriever
4
+
5
+
6
+ class Retriever:
7
+ def __init__(self, config):
8
+ self.config = config
9
+ self._create_retriever()
10
+
11
+ def _create_retriever(self):
12
+ if self.config["vectorstore"]["db_option"] == "FAISS":
13
+ self.retriever = FaissRetriever()
14
+ elif self.config["vectorstore"]["db_option"] == "Chroma":
15
+ self.retriever = ChromaRetriever()
16
+ elif self.config["vectorstore"]["db_option"] == "RAGatouille":
17
+ self.retriever = ColbertRetriever()
18
+ else:
19
+ raise ValueError(
20
+ "Invalid db_option: {}".format(self.config["vectorstore"]["db_option"])
21
+ )
22
+
23
+ def _return_retriever(self, db):
24
+ return self.retriever.return_retriever(db, self.config)
code/modules/vectorstore/store_manager.py CHANGED
@@ -1,6 +1,4 @@
1
- from modules.vectorstore.faiss import FaissVectorStore
2
- from modules.vectorstore.chroma import ChromaVectorStore
3
- from modules.vectorstore.colbert import ColbertVectorStore
4
  from modules.vectorstore.helpers import *
5
  from modules.dataloader.webpage_crawler import WebpageCrawler
6
  from modules.dataloader.data_loader import DataLoader
@@ -15,7 +13,6 @@ import asyncio
15
  class VectorStoreManager:
16
  def __init__(self, config, logger=None):
17
  self.config = config
18
- self.db_option = config["vectorstore"]["db_option"]
19
  self.document_names = None
20
 
21
  # Set up logging to both console and a file
@@ -47,9 +44,12 @@ class VectorStoreManager:
47
 
48
  self.webpage_crawler = WebpageCrawler()
49
 
 
 
50
  self.logger.info("VectorDB instance instantiated")
51
 
52
  def load_files(self):
 
53
  files = os.listdir(self.config["vectorstore"]["data_path"])
54
  files = [
55
  os.path.join(self.config["vectorstore"]["data_path"], file)
@@ -71,6 +71,7 @@ class VectorStoreManager:
71
  return files, urls
72
 
73
  def create_embedding_model(self):
 
74
  self.logger.info("Creating embedding function")
75
  embedding_model_loader = EmbeddingModelLoader(self.config)
76
  embedding_model = embedding_model_loader.load_embedding_model()
@@ -83,22 +84,23 @@ class VectorStoreManager:
83
  documents: list,
84
  document_metadata: list,
85
  ):
86
- if self.db_option in ["FAISS", "Chroma"]:
87
  self.embedding_model = self.create_embedding_model()
88
 
89
  self.logger.info("Initializing vector_db")
90
- self.logger.info("\tUsing {} as db_option".format(self.db_option))
91
- if self.db_option == "FAISS":
92
- self.vector_db = FaissVectorStore(self.config)
93
- self.vector_db.create_database(document_chunks, self.embedding_model)
94
- elif self.db_option == "Chroma":
95
- self.vector_db = ChromaVectorStore(self.config)
96
- self.vector_db.create_database(document_chunks, self.embedding_model)
97
- elif self.db_option == "RAGatouille":
98
- self.vector_db = ColbertVectorStore(self.config)
99
- self.vector_db.create_database(documents, document_names, document_metadata)
100
 
101
  def create_database(self):
 
102
  start_time = time.time() # Start time for creating database
103
  data_loader = DataLoader(self.config, self.logger)
104
  self.logger.info("Loading data")
@@ -126,18 +128,11 @@ class VectorStoreManager:
126
  )
127
 
128
  def load_database(self):
 
129
  start_time = time.time() # Start time for loading database
130
- if self.db_option in ["FAISS", "Chroma"]:
131
  self.embedding_model = self.create_embedding_model()
132
- if self.db_option == "FAISS":
133
- self.vector_db = FaissVectorStore(self.config)
134
- self.loaded_vector_db = self.vector_db.load_database(self.embedding_model)
135
- elif self.db_option == "Chroma":
136
- self.vector_db = ChromaVectorStore(self.config)
137
- self.loaded_vector_db = self.vector_db.load_database(self.embedding_model)
138
- elif self.db_option == "RAGatouille":
139
- self.vector_db = ColbertVectorStore(self.config)
140
- self.loaded_vector_db = self.vector_db.load_database()
141
  end_time = time.time() # End time for loading database
142
  self.logger.info(
143
  f"Time taken to load database: {end_time - start_time} seconds"
 
1
+ from modules.vectorstore.vectorstore import VectorStore
 
 
2
  from modules.vectorstore.helpers import *
3
  from modules.dataloader.webpage_crawler import WebpageCrawler
4
  from modules.dataloader.data_loader import DataLoader
 
13
  class VectorStoreManager:
14
  def __init__(self, config, logger=None):
15
  self.config = config
 
16
  self.document_names = None
17
 
18
  # Set up logging to both console and a file
 
44
 
45
  self.webpage_crawler = WebpageCrawler()
46
 
47
+ self.vector_db = VectorStore(self.config)
48
+
49
  self.logger.info("VectorDB instance instantiated")
50
 
51
  def load_files(self):
52
+
53
  files = os.listdir(self.config["vectorstore"]["data_path"])
54
  files = [
55
  os.path.join(self.config["vectorstore"]["data_path"], file)
 
71
  return files, urls
72
 
73
  def create_embedding_model(self):
74
+
75
  self.logger.info("Creating embedding function")
76
  embedding_model_loader = EmbeddingModelLoader(self.config)
77
  embedding_model = embedding_model_loader.load_embedding_model()
 
84
  documents: list,
85
  document_metadata: list,
86
  ):
87
+ if self.config["vectorstore"]["db_option"] in ["FAISS", "Chroma"]:
88
  self.embedding_model = self.create_embedding_model()
89
 
90
  self.logger.info("Initializing vector_db")
91
+ self.logger.info(
92
+ "\tUsing {} as db_option".format(self.config["vectorstore"]["db_option"])
93
+ )
94
+ self.vector_db._create_database(
95
+ document_chunks,
96
+ document_names,
97
+ documents,
98
+ document_metadata,
99
+ self.embedding_model,
100
+ )
101
 
102
  def create_database(self):
103
+
104
  start_time = time.time() # Start time for creating database
105
  data_loader = DataLoader(self.config, self.logger)
106
  self.logger.info("Loading data")
 
128
  )
129
 
130
  def load_database(self):
131
+
132
  start_time = time.time() # Start time for loading database
133
+ if self.config["vectorstore"]["db_option"] in ["FAISS", "Chroma"]:
134
  self.embedding_model = self.create_embedding_model()
135
+ self.loaded_vector_db = self.vector_db._load_database(self.embedding_model)
 
 
 
 
 
 
 
 
136
  end_time = time.time() # End time for loading database
137
  self.logger.info(
138
  f"Time taken to load database: {end_time - start_time} seconds"
code/modules/vectorstore/vectorstore.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from modules.vectorstore.faiss import FaissVectorStore
2
+ from modules.vectorstore.chroma import ChromaVectorStore
3
+ from modules.vectorstore.colbert import ColbertVectorStore
4
+
5
+
6
+ class VectorStore:
7
+ def __init__(self, config):
8
+ self.config = config
9
+ self.vectorstore = None
10
+
11
+ def _create_database(
12
+ self,
13
+ document_chunks,
14
+ document_names,
15
+ documents,
16
+ document_metadata,
17
+ embedding_model,
18
+ ):
19
+ if self.config["vectorstore"]["db_option"] == "FAISS":
20
+ self.vectorstore = FaissVectorStore(self.config)
21
+ self.vectorstore.create_database(document_chunks, embedding_model)
22
+ elif self.config["vectorstore"]["db_option"] == "Chroma":
23
+ self.vectorstore = ChromaVectorStore(self.config)
24
+ self.vectorstore.create_database(document_chunks, embedding_model)
25
+ elif self.config["vectorstore"]["db_option"] == "RAGatouille":
26
+ self.vectorstore = ColbertVectorStore(self.config)
27
+ self.vectorstore.create_database(
28
+ documents, document_names, document_metadata
29
+ )
30
+ else:
31
+ raise ValueError(
32
+ "Invalid db_option: {}".format(self.config["vectorstore"]["db_option"])
33
+ )
34
+
35
+ def _load_database(self, embedding_model):
36
+ if self.config["vectorstore"]["db_option"] == "FAISS":
37
+ self.vectorstore = FaissVectorStore(self.config)
38
+ return self.vectorstore.load_database(embedding_model)
39
+ elif self.config["vectorstore"]["db_option"] == "Chroma":
40
+ self.vectorstore = ChromaVectorStore(self.config)
41
+ return self.vectorstore.load_database(embedding_model)
42
+ elif self.config["vectorstore"]["db_option"] == "RAGatouille":
43
+ self.vectorstore = ColbertVectorStore(self.config)
44
+ return self.vectorstore.load_database()
45
+
46
+ def _as_retriever(self):
47
+ return self.vectorstore.as_retriever()
48
+
49
+ def _get_vectorstore(self):
50
+ return self.vectorstore