Commit e55576b by oceansweep: Upload 2 files
Parent(s): b3b2a0b
App_Function_Libraries/RAG/RAG_Library_2.py CHANGED
```diff
@@ -147,8 +147,6 @@ def enhanced_rag_pipeline(query: str, api_choice: str, keywords: str = None, fts
     try:
         # Load embedding provider from config, or fallback to 'openai'
         embedding_provider = config.get('Embeddings', 'provider', fallback='openai')
-
-        # Log the provider used
         logging.debug(f"Using embedding provider: {embedding_provider}")
 
         # Process keywords if provided
```
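The `fallback` keyword retained at the top of this hunk is standard `configparser` behavior: a missing `[Embeddings]` section or `provider` option yields `'openai'` instead of raising. A minimal, runnable illustration; only the section and option names come from the diff, and the `huggingface` value is an arbitrary stand-in:

```python
import configparser

config = configparser.ConfigParser()

# No [Embeddings] section loaded yet: get() falls back instead of raising.
print(config.get('Embeddings', 'provider', fallback='openai'))  # -> openai

config.read_string("""
[Embeddings]
provider = huggingface
""")

# Once the option exists, the configured value wins over the fallback.
print(config.get('Embeddings', 'provider', fallback='openai'))  # -> huggingface
```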
```diff
@@ -164,61 +162,41 @@ def enhanced_rag_pipeline(query: str, api_choice: str, keywords: str = None, fts
                 if db_type == "Media DB":
                     relevant_media_ids = fetch_relevant_media_ids(keyword_list)
                     relevant_ids[db_type] = relevant_media_ids
-                    logging.debug(f"enhanced_rag_pipeline - {db_type} relevant media IDs: {relevant_media_ids}")
-
                 elif db_type == "RAG Chat":
-                    conversations, total_pages, total_count = search_conversations_by_keywords(
-                        keywords=keyword_list)
-                    relevant_conversation_ids = [conv['conversation_id'] for conv in conversations]
-                    relevant_ids[db_type] = relevant_conversation_ids
-                    logging.debug(
-                        f"enhanced_rag_pipeline - {db_type} relevant conversation IDs: {relevant_conversation_ids}")
-
+                    conversations, total_pages, total_count = search_conversations_by_keywords(keywords=keyword_list)
+                    relevant_ids[db_type] = [conv['conversation_id'] for conv in conversations]
                 elif db_type == "RAG Notes":
                     notes, total_pages, total_count = get_notes_by_keywords(keyword_list)
-                    relevant_note_ids = [note_id for note_id, _, _, _ in notes]
-                    relevant_ids[db_type] = relevant_note_ids
-                    logging.debug(f"enhanced_rag_pipeline - {db_type} relevant note IDs: {relevant_note_ids}")
-
+                    relevant_ids[db_type] = [note_id for note_id, _, _, _ in notes]
                 elif db_type == "Character Chat":
-                    relevant_chat_ids = fetch_keywords_for_chats(keyword_list)
-                    relevant_ids[db_type] = relevant_chat_ids
-                    logging.debug(f"enhanced_rag_pipeline - {db_type} relevant chat IDs: {relevant_chat_ids}")
-
+                    relevant_ids[db_type] = fetch_keywords_for_chats(keyword_list)
                 elif db_type == "Character Cards":
-
-                    relevant_character_ids = fetch_character_ids_by_keywords(keyword_list)
-                    relevant_ids[db_type] = relevant_character_ids
-                    logging.debug(
-                        f"enhanced_rag_pipeline - {db_type} relevant character IDs: {relevant_character_ids}")
-
+                    relevant_ids[db_type] = fetch_character_ids_by_keywords(keyword_list)
                 else:
                     logging.error(f"Unsupported database type: {db_type}")
 
+                logging.debug(f"enhanced_rag_pipeline - {db_type} relevant IDs: {relevant_ids[db_type]}")
         except Exception as e:
             logging.error(f"Error fetching relevant IDs: {str(e)}")
     else:
         relevant_ids = None
 
-    #
-
-    relevant_media_ids_dict = {}
+    # Prepare relevant IDs for each database type
+    relevant_ids_dict = {}
     if relevant_ids:
         for db_type in database_types:
-            relevant_media_ids = relevant_ids.get(db_type)
-            if relevant_media_ids is not None:
-                # Convert to List[str] if not None
-                relevant_media_ids_dict[db_type] = [str(media_id) for media_id in relevant_media_ids]
+            if db_type in relevant_ids and relevant_ids[db_type]:
+                relevant_ids_dict[db_type] = [str(id_) for id_ in relevant_ids[db_type]]
             else:
-                relevant_media_ids_dict[db_type] = None
+                relevant_ids_dict[db_type] = None
     else:
-        relevant_media_ids_dict = {db_type: None for db_type in database_types}
+        relevant_ids_dict = {db_type: None for db_type in database_types}
 
     # Perform vector search for all selected databases
     vector_results = []
     for db_type in database_types:
         try:
-            db_relevant_ids = relevant_media_ids_dict.get(db_type)
+            db_relevant_ids = relevant_ids_dict.get(db_type)
             results = perform_vector_search(query, db_relevant_ids, top_k=fts_top_k)
             vector_results.extend(results)
             logging.debug(f"\nenhanced_rag_pipeline - Vector search results for {db_type}: {results}")
```
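This hunk collapses five near-duplicate branch bodies into single assignments and hoists the per-branch `logging.debug` to one shared line after the `if`/`elif` chain. One caveat: in the committed version that shared line also runs after the unsupported-type `else`, where `relevant_ids[db_type]` was never assigned, so it would raise a `KeyError` there. A minimal sketch of the consolidated shape; every name ending in `_stub` is a placeholder, not the project's API, and the sketch sidesteps the edge case with `continue`:

```python
import logging

logging.basicConfig(level=logging.DEBUG)

# Stand-ins for fetch_relevant_media_ids, get_notes_by_keywords, etc.
def fetch_media_ids_stub(keyword_list):
    return [1, 2, 3]

def fetch_note_ids_stub(keyword_list):
    return [10, 11]

def collect_relevant_ids(database_types, keyword_list):
    relevant_ids = {}
    for db_type in database_types:
        if db_type == "Media DB":
            relevant_ids[db_type] = fetch_media_ids_stub(keyword_list)
        elif db_type == "RAG Notes":
            relevant_ids[db_type] = fetch_note_ids_stub(keyword_list)
        else:
            logging.error(f"Unsupported database type: {db_type}")
            continue  # avoid the KeyError noted above
        # One shared log line replaces the five per-branch copies.
        logging.debug(f"{db_type} relevant IDs: {relevant_ids[db_type]}")
    return relevant_ids

# Mirror of the relevant_ids_dict preparation: stringify IDs per database,
# mapping missing or empty entries to None (None later means "no filter").
database_types = ["Media DB", "RAG Notes", "Bogus DB"]
ids = collect_relevant_ids(database_types, ["python"])
ids_dict = {db: [str(i) for i in ids[db]] if ids.get(db) else None
            for db in database_types}
print(ids_dict)  # {'Media DB': ['1', '2', '3'], 'RAG Notes': ['10', '11'], 'Bogus DB': None}
```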
```diff
@@ -227,8 +205,8 @@ def enhanced_rag_pipeline(query: str, api_choice: str, keywords: str = None, fts
 
     # Perform vector search
     # FIXME
-    vector_results = perform_vector_search(query, relevant_media_ids)
-    logging.debug(f"\n\nenhanced_rag_pipeline - Vector search results: {vector_results}")
+    #vector_results = perform_vector_search(query, relevant_media_ids)
+    #ogging.debug(f"\n\nenhanced_rag_pipeline - Vector search results: {vector_results}")
 
     # Perform full-text search
     #v1
```
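`perform_vector_search(query, db_relevant_ids, top_k=fts_top_k)` receives an optional per-database ID list, and the preparation step above stringifies IDs before the call, with `None` meaning no keyword filter. A self-contained sketch of that contract over a toy in-memory index; the two-dimensional vectors and the store are stand-ins, not the project's implementation:

```python
import math

# Toy store: string ID -> (embedding, content).
INDEX = {
    "1": ([1.0, 0.0], "doc about embeddings"),
    "2": ([0.0, 1.0], "doc about full-text search"),
}

def cosine(a, b):
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(x * x for x in b))
    return dot / (norm_a * norm_b) if norm_a and norm_b else 0.0

def vector_search_sketch(query_vec, relevant_ids=None, top_k=10):
    # None -> search the whole store; a list -> restrict to those string IDs.
    if relevant_ids is None:
        candidates = INDEX.items()
    else:
        candidates = ((i, INDEX[i]) for i in relevant_ids if i in INDEX)
    scored = sorted(((cosine(query_vec, vec), content)
                     for _, (vec, content) in candidates), reverse=True)
    return [{"content": content} for _, content in scored[:top_k]]

print(vector_search_sketch([0.9, 0.1]))         # unfiltered: both docs, best first
print(vector_search_sketch([0.9, 0.1], ["2"]))  # filtered: only ID "2"
```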
```diff
@@ -246,7 +224,7 @@ def enhanced_rag_pipeline(query: str, api_choice: str, keywords: str = None, fts
         except Exception as e:
             logging.error(f"Error performing full-text search on {db_type}: {str(e)}")
 
-    logging.debug("\n\nenhanced_rag_pipeline - Full-text search results:")
+    #logging.debug("\n\nenhanced_rag_pipeline - Full-text search results:")
     logging.debug(
         "\n\nenhanced_rag_pipeline - Full-text search results:\n" + "\n".join(
             [str(item) for item in fts_results]) + "\n"
```
```diff
@@ -255,7 +233,7 @@ def enhanced_rag_pipeline(query: str, api_choice: str, keywords: str = None, fts
     # Combine results
     all_results = vector_results + fts_results
 
-    if apply_re_ranking:
+    if apply_re_ranking and all_results:
         logging.debug(f"\nenhanced_rag_pipeline - Applying Re-Ranking")
         # FIXME - add option to use re-ranking at call time
         # FIXME - specify model + add param to modify at call time
```
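The only change in this hunk is the `and all_results` guard: re-ranking is skipped when vector and full-text search both came back empty. A toy illustration of the guarded call; the overlap scorer is a stand-in for whatever model the FIXMEs have in mind, not the project's re-ranker:

```python
def rerank(query, results, score_fn):
    # A real re-ranker usually sends the whole batch to a model, which is
    # where an empty list can blow up; hence guarding before the call.
    scored = [(score_fn(query, r["content"]), r) for r in results]
    scored.sort(key=lambda pair: pair[0], reverse=True)
    return [r for _, r in scored]

def overlap_score(query, text):
    # Shared lowercase-token count; a deliberately crude stand-in.
    return len(set(query.lower().split()) & set(text.lower().split()))

all_results = [{"content": "vector search result"},
               {"content": "full text search hit"}]
apply_re_ranking = True
if apply_re_ranking and all_results:  # the guard added in this commit
    all_results = rerank("full text search", all_results, overlap_score)
print([r["content"] for r in all_results])  # best-scoring content first
```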
```diff
@@ -282,7 +260,7 @@ def enhanced_rag_pipeline(query: str, api_choice: str, keywords: str = None, fts
 
     # Extract content from results (top fts_top_k by default)
     context = "\n".join([result['content'] for result in all_results[:fts_top_k]])
-    logging.debug(f"Context length: {len(context)}")
+    #logging.debug(f"Context length: {len(context)}")
     logging.debug(f"Context: {context[:200]}")
 
     # Generate answer using the selected API
```
```diff
@@ -294,10 +272,12 @@ def enhanced_rag_pipeline(query: str, api_choice: str, keywords: str = None, fts
             "answer": "No relevant information based on your query and keywords were found in the database. Your query has been directly passed to the LLM, and here is its answer: \n\n" + answer,
             "context": "No relevant information based on your query and keywords were found in the database. The only context used was your query: \n\n" + query
         }
+
     # Metrics
     pipeline_duration = time.time() - start_time
     log_histogram("enhanced_rag_pipeline_duration", pipeline_duration, labels={"api_choice": api_choice})
     log_counter("enhanced_rag_pipeline_success", labels={"api_choice": api_choice})
+
     return {
         "answer": answer,
         "context": context
```
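`log_histogram` and `log_counter` are project-internal helpers; the call sites above only reveal the shapes `log_histogram(name, value, labels=...)` and `log_counter(name, labels=...)`. A minimal local implementation consistent with those calls, offered purely as an assumption about their contract:

```python
import time
from collections import defaultdict

_histograms = defaultdict(list)
_counters = defaultdict(int)

def _key(name, labels):
    # Labels become part of the metric key, as the labels= kwarg suggests.
    return (name, tuple(sorted((labels or {}).items())))

def log_histogram(name, value, labels=None):
    _histograms[_key(name, labels)].append(value)

def log_counter(name, labels=None):
    _counters[_key(name, labels)] += 1

start_time = time.time()
# ... pipeline work would run here ...
pipeline_duration = time.time() - start_time
log_histogram("enhanced_rag_pipeline_duration", pipeline_duration,
              labels={"api_choice": "openai"})
log_counter("enhanced_rag_pipeline_success", labels={"api_choice": "openai"})
print(dict(_counters))
```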