Spaces:
Running
Running
oceansweep
committed on
Commit
•
d988bf2
1
Parent(s):
e55576b
Upload RAG_Library_2.py
Browse files
App_Function_Libraries/RAG/RAG_Library_2.py
CHANGED
@@ -128,75 +128,74 @@ search_functions = {
|
|
128 |
|
129 |
# RAG Search with keyword filtering
|
130 |
# FIXME - Update each called function to support modifiable top-k results
|
131 |
-
def enhanced_rag_pipeline(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
132 |
"""
|
133 |
Perform full text search across specified database type.
|
134 |
|
135 |
Args:
|
136 |
query: Search query string
|
137 |
api_choice: API to use for generating the response
|
138 |
-
fts_top_k: Maximum number of results to return
|
139 |
keywords: Optional comma-separated string of keywords used to filter results
|
140 |
-
|
|
|
|
|
141 |
|
142 |
Returns:
|
143 |
Dictionary containing search results with content
|
144 |
"""
|
145 |
log_counter("enhanced_rag_pipeline_attempt", labels={"api_choice": api_choice})
|
146 |
start_time = time.time()
|
|
|
147 |
try:
|
148 |
# Load embedding provider from config, or fallback to 'openai'
|
149 |
embedding_provider = config.get('Embeddings', 'provider', fallback='openai')
|
150 |
logging.debug(f"Using embedding provider: {embedding_provider}")
|
151 |
|
152 |
-
#
|
153 |
-
|
154 |
-
logging.debug(f"\n\nenhanced_rag_pipeline - Keywords: {keyword_list}")
|
155 |
|
156 |
-
|
|
|
|
|
|
|
157 |
|
158 |
-
# Fetch relevant IDs based on keywords if keywords are provided
|
159 |
-
if keyword_list:
|
160 |
try:
|
161 |
for db_type in database_types:
|
162 |
if db_type == "Media DB":
|
163 |
-
|
164 |
-
relevant_ids[db_type] =
|
165 |
elif db_type == "RAG Chat":
|
166 |
-
conversations,
|
167 |
-
relevant_ids[db_type] = [conv['conversation_id'] for conv in conversations]
|
168 |
elif db_type == "RAG Notes":
|
169 |
-
notes,
|
170 |
-
relevant_ids[db_type] = [note_id for note_id, _, _, _ in notes]
|
171 |
elif db_type == "Character Chat":
|
172 |
-
relevant_ids[db_type] = fetch_keywords_for_chats(keyword_list)
|
173 |
elif db_type == "Character Cards":
|
174 |
-
relevant_ids[db_type] = fetch_character_ids_by_keywords(keyword_list)
|
175 |
else:
|
176 |
logging.error(f"Unsupported database type: {db_type}")
|
177 |
|
178 |
logging.debug(f"enhanced_rag_pipeline - {db_type} relevant IDs: {relevant_ids[db_type]}")
|
179 |
except Exception as e:
|
180 |
logging.error(f"Error fetching relevant IDs: {str(e)}")
|
|
|
181 |
else:
|
182 |
-
relevant_ids = None
|
183 |
-
|
184 |
-
# Prepare relevant IDs for each database type
|
185 |
-
relevant_ids_dict = {}
|
186 |
-
if relevant_ids:
|
187 |
-
for db_type in database_types:
|
188 |
-
if db_type in relevant_ids and relevant_ids[db_type]:
|
189 |
-
relevant_ids_dict[db_type] = [str(id_) for id_ in relevant_ids[db_type]]
|
190 |
-
else:
|
191 |
-
relevant_ids_dict[db_type] = None
|
192 |
-
else:
|
193 |
-
relevant_ids_dict = {db_type: None for db_type in database_types}
|
194 |
|
195 |
-
# Perform vector search
|
196 |
vector_results = []
|
197 |
for db_type in database_types:
|
198 |
try:
|
199 |
-
db_relevant_ids =
|
200 |
results = perform_vector_search(query, db_relevant_ids, top_k=fts_top_k)
|
201 |
vector_results.extend(results)
|
202 |
logging.debug(f"\nenhanced_rag_pipeline - Vector search results for {db_type}: {results}")
|
@@ -217,7 +216,7 @@ def enhanced_rag_pipeline(query: str, api_choice: str, keywords: str = None, fts
|
|
217 |
fts_results = []
|
218 |
for db_type in database_types:
|
219 |
try:
|
220 |
-
db_relevant_ids = relevant_ids.get(db_type)
|
221 |
db_results = perform_full_text_search(query, db_type, db_relevant_ids, fts_top_k)
|
222 |
fts_results.extend(db_results)
|
223 |
logging.debug(f"enhanced_rag_pipeline - FTS results for {db_type}: {db_results}")
|
@@ -233,12 +232,12 @@ def enhanced_rag_pipeline(query: str, api_choice: str, keywords: str = None, fts
|
|
233 |
# Combine results
|
234 |
all_results = vector_results + fts_results
|
235 |
|
|
|
|
|
|
|
236 |
if apply_re_ranking and all_results:
|
237 |
logging.debug(f"\nenhanced_rag_pipeline - Applying Re-Ranking")
|
238 |
-
|
239 |
-
# FIXME - specify model + add param to modify at call time
|
240 |
-
# FIXME - add option to set a custom top X results
|
241 |
-
# You can specify a model if necessary, e.g., model_name="ms-marco-MiniLM-L-12-v2"
|
242 |
if all_results:
|
243 |
ranker = Ranker()
|
244 |
|
@@ -273,7 +272,7 @@ def enhanced_rag_pipeline(query: str, api_choice: str, keywords: str = None, fts
|
|
273 |
"context": "No relevant information based on your query and keywords were found in the database. The only context used was your query: \n\n" + query
|
274 |
}
|
275 |
|
276 |
-
#
|
277 |
pipeline_duration = time.time() - start_time
|
278 |
log_histogram("enhanced_rag_pipeline_duration", pipeline_duration, labels={"api_choice": api_choice})
|
279 |
log_counter("enhanced_rag_pipeline_success", labels={"api_choice": api_choice})
|
@@ -284,7 +283,6 @@ def enhanced_rag_pipeline(query: str, api_choice: str, keywords: str = None, fts
|
|
284 |
}
|
285 |
|
286 |
except Exception as e:
|
287 |
-
# Metrics
|
288 |
log_counter("enhanced_rag_pipeline_error", labels={"api_choice": api_choice, "error": str(e)})
|
289 |
logging.error(f"Error in enhanced_rag_pipeline: {str(e)}")
|
290 |
logging.error(f"Error in enhanced_rag_pipeline: {str(e)}")
|
|
|
128 |
|
129 |
# RAG Search with keyword filtering
|
130 |
# FIXME - Update each called function to support modifiable top-k results
|
131 |
+
def enhanced_rag_pipeline(
|
132 |
+
query: str,
|
133 |
+
api_choice: str,
|
134 |
+
keywords: Optional[str] = None,
|
135 |
+
fts_top_k: int = 10,
|
136 |
+
apply_re_ranking: bool = True,
|
137 |
+
database_types: List[str] = ["Media DB"]
|
138 |
+
) -> Dict[str, Any]:
|
139 |
"""
|
140 |
Perform full text search across specified database type.
|
141 |
|
142 |
Args:
|
143 |
query: Search query string
|
144 |
api_choice: API to use for generating the response
|
|
|
145 |
keywords: Optional comma-separated string of keywords used to filter results
|
146 |
+
fts_top_k: Maximum number of results to return
|
147 |
+
apply_re_ranking: Whether to apply re-ranking to results
|
148 |
+
database_types: List of database types to search
|
149 |
|
150 |
Returns:
|
151 |
Dictionary containing search results with content
|
152 |
"""
|
153 |
log_counter("enhanced_rag_pipeline_attempt", labels={"api_choice": api_choice})
|
154 |
start_time = time.time()
|
155 |
+
|
156 |
try:
|
157 |
# Load embedding provider from config, or fallback to 'openai'
|
158 |
embedding_provider = config.get('Embeddings', 'provider', fallback='openai')
|
159 |
logging.debug(f"Using embedding provider: {embedding_provider}")
|
160 |
|
161 |
+
# Initialize relevant IDs dictionary
|
162 |
+
relevant_ids: Dict[str, Optional[List[str]]] = {}
|
|
|
163 |
|
164 |
+
# Process keywords if provided
|
165 |
+
if keywords:
|
166 |
+
keyword_list = [k.strip().lower() for k in keywords.split(',')]
|
167 |
+
logging.debug(f"enhanced_rag_pipeline - Keywords: {keyword_list}")
|
168 |
|
|
|
|
|
169 |
try:
|
170 |
for db_type in database_types:
|
171 |
if db_type == "Media DB":
|
172 |
+
media_ids = fetch_relevant_media_ids(keyword_list)
|
173 |
+
relevant_ids[db_type] = [str(id_) for id_ in media_ids]
|
174 |
elif db_type == "RAG Chat":
|
175 |
+
conversations, _, _ = search_conversations_by_keywords(keywords=keyword_list)
|
176 |
+
relevant_ids[db_type] = [str(conv['conversation_id']) for conv in conversations]
|
177 |
elif db_type == "RAG Notes":
|
178 |
+
notes, _, _ = get_notes_by_keywords(keyword_list)
|
179 |
+
relevant_ids[db_type] = [str(note_id) for note_id, _, _, _ in notes]
|
180 |
elif db_type == "Character Chat":
|
181 |
+
relevant_ids[db_type] = [str(id_) for id_ in fetch_keywords_for_chats(keyword_list)]
|
182 |
elif db_type == "Character Cards":
|
183 |
+
relevant_ids[db_type] = [str(id_) for id_ in fetch_character_ids_by_keywords(keyword_list)]
|
184 |
else:
|
185 |
logging.error(f"Unsupported database type: {db_type}")
|
186 |
|
187 |
logging.debug(f"enhanced_rag_pipeline - {db_type} relevant IDs: {relevant_ids[db_type]}")
|
188 |
except Exception as e:
|
189 |
logging.error(f"Error fetching relevant IDs: {str(e)}")
|
190 |
+
relevant_ids = {db_type: None for db_type in database_types}
|
191 |
else:
|
192 |
+
relevant_ids = {db_type: None for db_type in database_types}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
193 |
|
194 |
+
# Perform vector search
|
195 |
vector_results = []
|
196 |
for db_type in database_types:
|
197 |
try:
|
198 |
+
db_relevant_ids = relevant_ids.get(db_type)
|
199 |
results = perform_vector_search(query, db_relevant_ids, top_k=fts_top_k)
|
200 |
vector_results.extend(results)
|
201 |
logging.debug(f"\nenhanced_rag_pipeline - Vector search results for {db_type}: {results}")
|
|
|
216 |
fts_results = []
|
217 |
for db_type in database_types:
|
218 |
try:
|
219 |
+
db_relevant_ids = relevant_ids.get(db_type)
|
220 |
db_results = perform_full_text_search(query, db_type, db_relevant_ids, fts_top_k)
|
221 |
fts_results.extend(db_results)
|
222 |
logging.debug(f"enhanced_rag_pipeline - FTS results for {db_type}: {db_results}")
|
|
|
232 |
# Combine results
|
233 |
all_results = vector_results + fts_results
|
234 |
|
235 |
+
# FIXME - specify model + add param to modify at call time
|
236 |
+
# You can specify a model if necessary, e.g., model_name="ms-marco-MiniLM-L-12-v2"
|
237 |
+
# Apply re-ranking if enabled and results exist
|
238 |
if apply_re_ranking and all_results:
|
239 |
logging.debug(f"\nenhanced_rag_pipeline - Applying Re-Ranking")
|
240 |
+
|
|
|
|
|
|
|
241 |
if all_results:
|
242 |
ranker = Ranker()
|
243 |
|
|
|
272 |
"context": "No relevant information based on your query and keywords were found in the database. The only context used was your query: \n\n" + query
|
273 |
}
|
274 |
|
275 |
+
# Log metrics
|
276 |
pipeline_duration = time.time() - start_time
|
277 |
log_histogram("enhanced_rag_pipeline_duration", pipeline_duration, labels={"api_choice": api_choice})
|
278 |
log_counter("enhanced_rag_pipeline_success", labels={"api_choice": api_choice})
|
|
|
283 |
}
|
284 |
|
285 |
except Exception as e:
|
|
|
286 |
log_counter("enhanced_rag_pipeline_error", labels={"api_choice": api_choice, "error": str(e)})
|
287 |
logging.error(f"Error in enhanced_rag_pipeline: {str(e)}")
|
288 |
logging.error(f"Error in enhanced_rag_pipeline: {str(e)}")
|