oceansweep committed on
Commit
d988bf2
1 Parent(s): e55576b

Upload RAG_Library_2.py

Browse files
App_Function_Libraries/RAG/RAG_Library_2.py CHANGED
@@ -128,75 +128,74 @@ search_functions = {
128
 
129
  # RAG Search with keyword filtering
130
  # FIXME - Update each called function to support modifiable top-k results
131
- def enhanced_rag_pipeline(query: str, api_choice: str, keywords: str = None, fts_top_k=10, apply_re_ranking=True, database_types: List[str] = "Media DB") -> Dict[str, Any]:
 
 
 
 
 
 
 
132
  """
133
  Perform full text search across specified database type.
134
 
135
  Args:
136
  query: Search query string
137
  api_choice: API to use for generating the response
138
- fts_top_k: Maximum number of results to return
139
  keywords: Optional list of media IDs to filter results
140
- database_types: Type of database to search ("Media DB", "RAG Chat", or "Character Chat")
 
 
141
 
142
  Returns:
143
  Dictionary containing search results with content
144
  """
145
  log_counter("enhanced_rag_pipeline_attempt", labels={"api_choice": api_choice})
146
  start_time = time.time()
 
147
  try:
148
  # Load embedding provider from config, or fallback to 'openai'
149
  embedding_provider = config.get('Embeddings', 'provider', fallback='openai')
150
  logging.debug(f"Using embedding provider: {embedding_provider}")
151
 
152
- # Process keywords if provided
153
- keyword_list = [k.strip().lower() for k in keywords.split(',')] if keywords else []
154
- logging.debug(f"\n\nenhanced_rag_pipeline - Keywords: {keyword_list}")
155
 
156
- relevant_ids = {}
 
 
 
157
 
158
- # Fetch relevant IDs based on keywords if keywords are provided
159
- if keyword_list:
160
  try:
161
  for db_type in database_types:
162
  if db_type == "Media DB":
163
- relevant_media_ids = fetch_relevant_media_ids(keyword_list)
164
- relevant_ids[db_type] = relevant_media_ids
165
  elif db_type == "RAG Chat":
166
- conversations, total_pages, total_count = search_conversations_by_keywords(keywords=keyword_list)
167
- relevant_ids[db_type] = [conv['conversation_id'] for conv in conversations]
168
  elif db_type == "RAG Notes":
169
- notes, total_pages, total_count = get_notes_by_keywords(keyword_list)
170
- relevant_ids[db_type] = [note_id for note_id, _, _, _ in notes]
171
  elif db_type == "Character Chat":
172
- relevant_ids[db_type] = fetch_keywords_for_chats(keyword_list)
173
  elif db_type == "Character Cards":
174
- relevant_ids[db_type] = fetch_character_ids_by_keywords(keyword_list)
175
  else:
176
  logging.error(f"Unsupported database type: {db_type}")
177
 
178
  logging.debug(f"enhanced_rag_pipeline - {db_type} relevant IDs: {relevant_ids[db_type]}")
179
  except Exception as e:
180
  logging.error(f"Error fetching relevant IDs: {str(e)}")
 
181
  else:
182
- relevant_ids = None
183
-
184
- # Prepare relevant IDs for each database type
185
- relevant_ids_dict = {}
186
- if relevant_ids:
187
- for db_type in database_types:
188
- if db_type in relevant_ids and relevant_ids[db_type]:
189
- relevant_ids_dict[db_type] = [str(id_) for id_ in relevant_ids[db_type]]
190
- else:
191
- relevant_ids_dict[db_type] = None
192
- else:
193
- relevant_ids_dict = {db_type: None for db_type in database_types}
194
 
195
- # Perform vector search for all selected databases
196
  vector_results = []
197
  for db_type in database_types:
198
  try:
199
- db_relevant_ids = relevant_ids_dict.get(db_type)
200
  results = perform_vector_search(query, db_relevant_ids, top_k=fts_top_k)
201
  vector_results.extend(results)
202
  logging.debug(f"\nenhanced_rag_pipeline - Vector search results for {db_type}: {results}")
@@ -217,7 +216,7 @@ def enhanced_rag_pipeline(query: str, api_choice: str, keywords: str = None, fts
217
  fts_results = []
218
  for db_type in database_types:
219
  try:
220
- db_relevant_ids = relevant_ids.get(db_type) if relevant_ids else None
221
  db_results = perform_full_text_search(query, db_type, db_relevant_ids, fts_top_k)
222
  fts_results.extend(db_results)
223
  logging.debug(f"enhanced_rag_pipeline - FTS results for {db_type}: {db_results}")
@@ -233,12 +232,12 @@ def enhanced_rag_pipeline(query: str, api_choice: str, keywords: str = None, fts
233
  # Combine results
234
  all_results = vector_results + fts_results
235
 
 
 
 
236
  if apply_re_ranking and all_results:
237
  logging.debug(f"\nenhanced_rag_pipeline - Applying Re-Ranking")
238
- # FIXME - add option to use re-ranking at call time
239
- # FIXME - specify model + add param to modify at call time
240
- # FIXME - add option to set a custom top X results
241
- # You can specify a model if necessary, e.g., model_name="ms-marco-MiniLM-L-12-v2"
242
  if all_results:
243
  ranker = Ranker()
244
 
@@ -273,7 +272,7 @@ def enhanced_rag_pipeline(query: str, api_choice: str, keywords: str = None, fts
273
  "context": "No relevant information based on your query and keywords were found in the database. The only context used was your query: \n\n" + query
274
  }
275
 
276
- # Metrics
277
  pipeline_duration = time.time() - start_time
278
  log_histogram("enhanced_rag_pipeline_duration", pipeline_duration, labels={"api_choice": api_choice})
279
  log_counter("enhanced_rag_pipeline_success", labels={"api_choice": api_choice})
@@ -284,7 +283,6 @@ def enhanced_rag_pipeline(query: str, api_choice: str, keywords: str = None, fts
284
  }
285
 
286
  except Exception as e:
287
- # Metrics
288
  log_counter("enhanced_rag_pipeline_error", labels={"api_choice": api_choice, "error": str(e)})
289
  logging.error(f"Error in enhanced_rag_pipeline: {str(e)}")
290
  logging.error(f"Error in enhanced_rag_pipeline: {str(e)}")
 
128
 
129
  # RAG Search with keyword filtering
130
  # FIXME - Update each called function to support modifiable top-k results
131
+ def enhanced_rag_pipeline(
132
+ query: str,
133
+ api_choice: str,
134
+ keywords: Optional[str] = None,
135
+ fts_top_k: int = 10,
136
+ apply_re_ranking: bool = True,
137
+ database_types: List[str] = ["Media DB"]
138
+ ) -> Dict[str, Any]:
139
  """
140
  Perform full text search across specified database type.
141
 
142
  Args:
143
  query: Search query string
144
  api_choice: API to use for generating the response
 
145
  keywords: Optional list of media IDs to filter results
146
+ fts_top_k: Maximum number of results to return
147
+ apply_re_ranking: Whether to apply re-ranking to results
148
+ database_types: Type of database to search
149
 
150
  Returns:
151
  Dictionary containing search results with content
152
  """
153
  log_counter("enhanced_rag_pipeline_attempt", labels={"api_choice": api_choice})
154
  start_time = time.time()
155
+
156
  try:
157
  # Load embedding provider from config, or fallback to 'openai'
158
  embedding_provider = config.get('Embeddings', 'provider', fallback='openai')
159
  logging.debug(f"Using embedding provider: {embedding_provider}")
160
 
161
+ # Initialize relevant IDs dictionary
162
+ relevant_ids: Dict[str, Optional[List[str]]] = {}
 
163
 
164
+ # Process keywords if provided
165
+ if keywords:
166
+ keyword_list = [k.strip().lower() for k in keywords.split(',')]
167
+ logging.debug(f"enhanced_rag_pipeline - Keywords: {keyword_list}")
168
 
 
 
169
  try:
170
  for db_type in database_types:
171
  if db_type == "Media DB":
172
+ media_ids = fetch_relevant_media_ids(keyword_list)
173
+ relevant_ids[db_type] = [str(id_) for id_ in media_ids]
174
  elif db_type == "RAG Chat":
175
+ conversations, _, _ = search_conversations_by_keywords(keywords=keyword_list)
176
+ relevant_ids[db_type] = [str(conv['conversation_id']) for conv in conversations]
177
  elif db_type == "RAG Notes":
178
+ notes, _, _ = get_notes_by_keywords(keyword_list)
179
+ relevant_ids[db_type] = [str(note_id) for note_id, _, _, _ in notes]
180
  elif db_type == "Character Chat":
181
+ relevant_ids[db_type] = [str(id_) for id_ in fetch_keywords_for_chats(keyword_list)]
182
  elif db_type == "Character Cards":
183
+ relevant_ids[db_type] = [str(id_) for id_ in fetch_character_ids_by_keywords(keyword_list)]
184
  else:
185
  logging.error(f"Unsupported database type: {db_type}")
186
 
187
  logging.debug(f"enhanced_rag_pipeline - {db_type} relevant IDs: {relevant_ids[db_type]}")
188
  except Exception as e:
189
  logging.error(f"Error fetching relevant IDs: {str(e)}")
190
+ relevant_ids = {db_type: None for db_type in database_types}
191
  else:
192
+ relevant_ids = {db_type: None for db_type in database_types}
 
 
 
 
 
 
 
 
 
 
 
193
 
194
+ # Perform vector search
195
  vector_results = []
196
  for db_type in database_types:
197
  try:
198
+ db_relevant_ids = relevant_ids.get(db_type)
199
  results = perform_vector_search(query, db_relevant_ids, top_k=fts_top_k)
200
  vector_results.extend(results)
201
  logging.debug(f"\nenhanced_rag_pipeline - Vector search results for {db_type}: {results}")
 
216
  fts_results = []
217
  for db_type in database_types:
218
  try:
219
+ db_relevant_ids = relevant_ids.get(db_type)
220
  db_results = perform_full_text_search(query, db_type, db_relevant_ids, fts_top_k)
221
  fts_results.extend(db_results)
222
  logging.debug(f"enhanced_rag_pipeline - FTS results for {db_type}: {db_results}")
 
232
  # Combine results
233
  all_results = vector_results + fts_results
234
 
235
+ # FIXME - specify model + add param to modify at call time
236
+ # You can specify a model if necessary, e.g., model_name="ms-marco-MiniLM-L-12-v2"
237
+ # Apply re-ranking if enabled and results exist
238
  if apply_re_ranking and all_results:
239
  logging.debug(f"\nenhanced_rag_pipeline - Applying Re-Ranking")
240
+
 
 
 
241
  if all_results:
242
  ranker = Ranker()
243
 
 
272
  "context": "No relevant information based on your query and keywords were found in the database. The only context used was your query: \n\n" + query
273
  }
274
 
275
+ # Log metrics
276
  pipeline_duration = time.time() - start_time
277
  log_histogram("enhanced_rag_pipeline_duration", pipeline_duration, labels={"api_choice": api_choice})
278
  log_counter("enhanced_rag_pipeline_success", labels={"api_choice": api_choice})
 
283
  }
284
 
285
  except Exception as e:
 
286
  log_counter("enhanced_rag_pipeline_error", labels={"api_choice": api_choice, "error": str(e)})
287
  logging.error(f"Error in enhanced_rag_pipeline: {str(e)}")
288
  logging.error(f"Error in enhanced_rag_pipeline: {str(e)}")