oceansweep committed
Commit e55576b
1 Parent(s): b3b2a0b

Upload 2 files

App_Function_Libraries/RAG/RAG_Library_2.py CHANGED
@@ -147,8 +147,6 @@ def enhanced_rag_pipeline(query: str, api_choice: str, keywords: str = None, fts
     try:
         # Load embedding provider from config, or fallback to 'openai'
         embedding_provider = config.get('Embeddings', 'provider', fallback='openai')
-
-        # Log the provider used
         logging.debug(f"Using embedding provider: {embedding_provider}")
 
         # Process keywords if provided
@@ -164,61 +162,41 @@ def enhanced_rag_pipeline(query: str, api_choice: str, keywords: str = None, fts
                 if db_type == "Media DB":
                     relevant_media_ids = fetch_relevant_media_ids(keyword_list)
                     relevant_ids[db_type] = relevant_media_ids
-                    logging.debug(f"enhanced_rag_pipeline - {db_type} relevant media IDs: {relevant_media_ids}")
-
                 elif db_type == "RAG Chat":
-                    conversations, total_pages, total_count = search_conversations_by_keywords(
-                        keywords=keyword_list)
-                    relevant_conversation_ids = [conv['conversation_id'] for conv in conversations]
-                    relevant_ids[db_type] = relevant_conversation_ids
-                    logging.debug(
-                        f"enhanced_rag_pipeline - {db_type} relevant conversation IDs: {relevant_conversation_ids}")
-
+                    conversations, total_pages, total_count = search_conversations_by_keywords(keywords=keyword_list)
+                    relevant_ids[db_type] = [conv['conversation_id'] for conv in conversations]
                 elif db_type == "RAG Notes":
                     notes, total_pages, total_count = get_notes_by_keywords(keyword_list)
-                    relevant_note_ids = [note_id for note_id, _, _, _ in notes]  # Unpack note_id from the tuple
-                    relevant_ids[db_type] = relevant_note_ids
-                    logging.debug(f"enhanced_rag_pipeline - {db_type} relevant note IDs: {relevant_note_ids}")
-
+                    relevant_ids[db_type] = [note_id for note_id, _, _, _ in notes]
                 elif db_type == "Character Chat":
-                    relevant_chat_ids = fetch_keywords_for_chats(keyword_list)
-                    relevant_ids[db_type] = relevant_chat_ids
-                    logging.debug(f"enhanced_rag_pipeline - {db_type} relevant chat IDs: {relevant_chat_ids}")
-
+                    relevant_ids[db_type] = fetch_keywords_for_chats(keyword_list)
                 elif db_type == "Character Cards":
-                    # Assuming we have a function to fetch character IDs by keywords
-                    relevant_character_ids = fetch_character_ids_by_keywords(keyword_list)
-                    relevant_ids[db_type] = relevant_character_ids
-                    logging.debug(
-                        f"enhanced_rag_pipeline - {db_type} relevant character IDs: {relevant_character_ids}")
-
+                    relevant_ids[db_type] = fetch_character_ids_by_keywords(keyword_list)
                 else:
                     logging.error(f"Unsupported database type: {db_type}")
 
+                logging.debug(f"enhanced_rag_pipeline - {db_type} relevant IDs: {relevant_ids[db_type]}")
         except Exception as e:
             logging.error(f"Error fetching relevant IDs: {str(e)}")
     else:
         relevant_ids = None
 
-    # Extract relevant media IDs for each selected DB
-    # Prepare a dict to hold relevant_media_ids per DB
-    relevant_media_ids_dict = {}
+    # Prepare relevant IDs for each database type
+    relevant_ids_dict = {}
     if relevant_ids:
         for db_type in database_types:
-            relevant_media_ids = relevant_ids.get(db_type, None)
-            if relevant_media_ids:
-                # Convert to List[str] if not None
-                relevant_media_ids_dict[db_type] = [str(media_id) for media_id in relevant_media_ids]
+            if db_type in relevant_ids and relevant_ids[db_type]:
+                relevant_ids_dict[db_type] = [str(id_) for id_ in relevant_ids[db_type]]
            else:
-                relevant_media_ids_dict[db_type] = None
+                relevant_ids_dict[db_type] = None
     else:
-        relevant_media_ids_dict = {db_type: None for db_type in database_types}
+        relevant_ids_dict = {db_type: None for db_type in database_types}
 
     # Perform vector search for all selected databases
     vector_results = []
     for db_type in database_types:
         try:
-            db_relevant_ids = relevant_media_ids_dict.get(db_type)
+            db_relevant_ids = relevant_ids_dict.get(db_type)
             results = perform_vector_search(query, db_relevant_ids, top_k=fts_top_k)
             vector_results.extend(results)
             logging.debug(f"\nenhanced_rag_pipeline - Vector search results for {db_type}: {results}")
@@ -227,8 +205,8 @@ def enhanced_rag_pipeline(query: str, api_choice: str, keywords: str = None, fts
 
     # Perform vector search
     # FIXME
-    vector_results = perform_vector_search(query, relevant_media_ids)
-    logging.debug(f"\n\nenhanced_rag_pipeline - Vector search results: {vector_results}")
+    #vector_results = perform_vector_search(query, relevant_media_ids)
+    #ogging.debug(f"\n\nenhanced_rag_pipeline - Vector search results: {vector_results}")
 
     # Perform full-text search
     #v1
@@ -246,7 +224,7 @@ def enhanced_rag_pipeline(query: str, api_choice: str, keywords: str = None, fts
         except Exception as e:
             logging.error(f"Error performing full-text search on {db_type}: {str(e)}")
 
-    logging.debug("\n\nenhanced_rag_pipeline - Full-text search results:")
+    #logging.debug("\n\nenhanced_rag_pipeline - Full-text search results:")
     logging.debug(
         "\n\nenhanced_rag_pipeline - Full-text search results:\n" + "\n".join(
             [str(item) for item in fts_results]) + "\n"
@@ -255,7 +233,7 @@ def enhanced_rag_pipeline(query: str, api_choice: str, keywords: str = None, fts
     # Combine results
     all_results = vector_results + fts_results
 
-    if apply_re_ranking:
+    if apply_re_ranking and all_results:
         logging.debug(f"\nenhanced_rag_pipeline - Applying Re-Ranking")
         # FIXME - add option to use re-ranking at call time
         # FIXME - specify model + add param to modify at call time
@@ -282,7 +260,7 @@ def enhanced_rag_pipeline(query: str, api_choice: str, keywords: str = None, fts
 
     # Extract content from results (top fts_top_k by default)
     context = "\n".join([result['content'] for result in all_results[:fts_top_k]])
-    logging.debug(f"Context length: {len(context)}")
+    #logging.debug(f"Context length: {len(context)}")
     logging.debug(f"Context: {context[:200]}")
 
     # Generate answer using the selected API
@@ -294,10 +272,12 @@ def enhanced_rag_pipeline(query: str, api_choice: str, keywords: str = None, fts
             "answer": "No relevant information based on your query and keywords were found in the database. Your query has been directly passed to the LLM, and here is its answer: \n\n" + answer,
             "context": "No relevant information based on your query and keywords were found in the database. The only context used was your query: \n\n" + query
         }
+
     # Metrics
     pipeline_duration = time.time() - start_time
     log_histogram("enhanced_rag_pipeline_duration", pipeline_duration, labels={"api_choice": api_choice})
     log_counter("enhanced_rag_pipeline_success", labels={"api_choice": api_choice})
+
     return {
         "answer": answer,
         "context": context
 
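The largest hunk above collapses five per-source lookup branches into one shape: each database type resolves keywords to a flat list of IDs stored under a single dict key, one shared debug line replaces five bespoke ones, and the IDs are then normalized to Optional[List[str]] before the vector search. Below is a minimal runnable sketch of that pattern; the fetch_* stubs are hypothetical stand-ins for the project's real lookup functions, whose names appear in the diff but whose bodies do not.

import logging
from typing import Dict, List, Optional

def fetch_relevant_media_ids(keyword_list: List[str]) -> List[int]:
    # Hypothetical stub; the real function lives elsewhere in the repo
    return [1, 2]

def fetch_keywords_for_chats(keyword_list: List[str]) -> List[int]:
    # Hypothetical stub
    return [7]

def collect_relevant_ids(keyword_list: List[str],
                         database_types: List[str]) -> Dict[str, Optional[List[str]]]:
    relevant_ids: Dict[str, List[int]] = {}
    for db_type in database_types:
        if db_type == "Media DB":
            relevant_ids[db_type] = fetch_relevant_media_ids(keyword_list)
        elif db_type == "Character Chat":
            relevant_ids[db_type] = fetch_keywords_for_chats(keyword_list)
        else:
            logging.error(f"Unsupported database type: {db_type}")
            continue
        # One shared debug line per source, as in the refactored loop
        logging.debug(f"{db_type} relevant IDs: {relevant_ids[db_type]}")

    # Normalize: stringify every ID; None marks "no filter" for that DB type
    return {
        db_type: [str(id_) for id_ in relevant_ids[db_type]]
        if relevant_ids.get(db_type) else None
        for db_type in database_types
    }

if __name__ == "__main__":
    print(collect_relevant_ids(["python"], ["Media DB", "Character Chat", "Bogus DB"]))
    # {'Media DB': ['1', '2'], 'Character Chat': ['7'], 'Bogus DB': None}

Normalizing to stringified IDs up front is what lets perform_vector_search take one uniform filter argument per database type instead of media-specific variables.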