Spaces:
Running
Running
File size: 5,910 Bytes
1be405f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 |
# RAG_QA_Chat.py
# Description: Functions supporting the RAG QA Chat functionality
#
# Imports
#
#
# External Imports
import json
import logging
import tempfile
import time
from typing import List, Tuple, IO, Union
#
# Local Imports
from App_Function_Libraries.DB.DB_Manager import db, search_db, DatabaseError, get_media_content
from App_Function_Libraries.RAG.RAG_Library_2 import generate_answer
from App_Function_Libraries.Metrics.metrics_logger import log_counter, log_histogram
#
########################################################################################################################
#
# Functions:
def _resolve_rag_context(context: Union[str, IO[str]]) -> str:
    """Return the text that rag_qa_chat places in the prompt's Context section.

    Three kinds of ``context`` are recognised, checked in this order:

    * an object with a ``read`` attribute (e.g. an uploaded file): its contents
      are read, and decoded as UTF-8 when ``read()`` yields bytes;
    * a string starting with ``"media_id:"``: the integer between the first and
      second colon is passed to ``get_media_content``;
    * anything else: converted with ``str()``.

    A counter is logged for whichever branch was taken. Errors (bad integer,
    failed read, database failure) propagate to the caller.
    """
    if hasattr(context, 'read'):
        # Handle uploaded file
        context_text = context.read()
        if isinstance(context_text, bytes):
            context_text = context_text.decode('utf-8')
        log_counter("rag_qa_chat_uploaded_file")
        return context_text

    if isinstance(context, str) and context.startswith("media_id:"):
        # Handle existing file or search result
        media_id = int(context.split(":")[1])
        context_text = get_media_content(media_id)
        log_counter("rag_qa_chat_existing_media", labels={"media_id": media_id})
        return context_text

    context_text = str(context)
    log_counter("rag_qa_chat_string_context")
    return context_text


def rag_qa_chat(message: str, history: List[Tuple[str, str]], context: Union[str, IO[str]], api_choice: str) -> Tuple[List[Tuple[str, str]], str]:
    """Answer ``message`` using the given context and the prior chat history.

    The prompt is built from every ``(human, ai)`` pair in ``history``, followed
    by the resolved context text and the new message, and is handed to
    ``generate_answer(api_choice, full_context, message)``.

    Args:
        message: The user's new question.
        history: Previous ``(human, ai)`` exchanges. The list is extended in
            place with the new exchange on success. ``None`` is tolerated and
            treated as an empty conversation.
        context: An object with ``read()``, a ``"media_id:<int>"`` string, or
            any other value (stringified) to use as context.
        api_choice: Identifier forwarded to ``generate_answer`` and used as a
            metrics label.

    Returns:
        ``(history, "")`` on success, where ``history`` now ends with
        ``(message, response)``. On failure, ``(history, error_message)`` with
        ``history`` left as it was; exceptions raised inside the ``try`` are
        reported through ``error_message`` rather than propagated.
    """
    log_counter("rag_qa_chat_attempt", labels={"api_choice": api_choice})
    start_time = time.time()

    # A missing history used to surface as an "unexpected error" (iterating
    # None); treat it as the start of a conversation instead.
    if history is None:
        history = []

    try:
        # Prepare the context based on the selected source
        context_text = _resolve_rag_context(context)

        # Prepare the full context including chat history
        full_context = "\n".join([f"Human: {h[0]}\nAI: {h[1]}" for h in history])
        full_context += f"\n\nContext: {context_text}\n\nHuman: {message}\nAI:"

        # Generate response using the selected API
        response = generate_answer(api_choice, full_context, message)

        # Update history
        history.append((message, response))

        chat_duration = time.time() - start_time
        log_histogram("rag_qa_chat_duration", chat_duration, labels={"api_choice": api_choice})
        log_counter("rag_qa_chat_success", labels={"api_choice": api_choice})
        return history, ""
    except DatabaseError as e:
        log_counter("rag_qa_chat_database_error", labels={"error": str(e)})
        # The exception is converted to a string for the caller, so the log is
        # the only place the traceback can survive: keep it with exc_info.
        logging.error(f"Database error in rag_qa_chat: {str(e)}", exc_info=True)
        return history, f"An error occurred while accessing the database: {str(e)}"
    except Exception as e:
        log_counter("rag_qa_chat_unexpected_error", labels={"error": str(e)})
        logging.error(f"Unexpected error in rag_qa_chat: {str(e)}", exc_info=True)
        return history, f"An unexpected error occurred: {str(e)}"


def save_chat_history(history: List[Tuple[str, str]]) -> str:
    """Serialize ``history`` as JSON into a new temporary ``.json`` file.

    The file is created with ``delete=False``, so it remains on disk after
    this function returns.

    Args:
        history: Chat exchanges to serialize with ``json.dump``.

    Returns:
        The path (``name``) of the temporary file that was written.

    Raises:
        Exception: Any error raised while creating or writing the file is
            counted, logged and re-raised unchanged.
    """
    log_counter("save_chat_history_attempt")
    began_at = time.time()
    try:
        with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.json') as handle:
            json.dump(history, handle)
            # Duration and success are recorded while the file is still open.
            log_histogram("save_chat_history_duration", time.time() - began_at)
            log_counter("save_chat_history_success")
            saved_path = handle.name
        return saved_path
    except Exception as exc:
        reason = str(exc)
        log_counter("save_chat_history_error", labels={"error": reason})
        logging.error(f"Error saving chat history: {reason}")
        raise


def load_chat_history(file: IO[str]) -> List[Tuple[str, str]]:
    """Parse a chat history from an open JSON file object.

    Args:
        file: A readable file object handed straight to ``json.load``.

    Returns:
        Whatever ``json.load`` produced for the file's contents.

    Raises:
        Exception: Any error raised while reading or parsing is counted,
            logged and re-raised unchanged.
    """
    log_counter("load_chat_history_attempt")
    began_at = time.time()
    try:
        # Load chat history from a file
        loaded = json.load(file)
        elapsed = time.time() - began_at
        log_histogram("load_chat_history_duration", elapsed)
        log_counter("load_chat_history_success")
    except Exception as exc:
        reason = str(exc)
        log_counter("load_chat_history_error", labels={"error": reason})
        logging.error(f"Error loading chat history: {reason}")
        raise
    return loaded


def search_database(query: str) -> List[Tuple[int, str]]:
    """Search the database and return ``(id, title)`` pairs for the matches.

    The search is delegated to ``search_db`` over the ``title`` and
    ``content`` fields, requesting page 1 with 10 results per page.

    Args:
        query: The search text forwarded to ``search_db``.

    Returns:
        One ``(result['id'], result['title'])`` tuple per returned result.

    Raises:
        Exception: Any error raised during the search is counted, logged and
            re-raised unchanged.
    """
    try:
        log_counter("search_database_attempt")
        began_at = time.time()

        # Implement database search functionality
        search_fields = ["title", "content"]
        matches = search_db(query, search_fields, "", page=1, results_per_page=10)

        log_histogram("search_database_duration", time.time() - began_at)
        log_counter("search_database_success", labels={"result_count": len(matches)})

        id_title_pairs = []
        for match in matches:
            id_title_pairs.append((match['id'], match['title']))
        return id_title_pairs
    except Exception as exc:
        reason = str(exc)
        log_counter("search_database_error", labels={"error": reason})
        logging.error(f"Error searching database: {reason}")
        raise


def get_existing_files() -> List[Tuple[int, str]]:
    """Fetch the ``id`` and ``title`` of every row in ``Media``, ordered by title.

    Returns:
        The rows returned by ``cursor.fetchall()`` for the query.

    Raises:
        Exception: Any error raised while connecting or querying is counted,
            logged and re-raised unchanged.
    """
    log_counter("get_existing_files_attempt")
    began_at = time.time()
    try:
        # Fetch list of existing files from the database
        listing_sql = "SELECT id, title FROM Media ORDER BY title"
        with db.get_connection() as connection:
            cur = connection.cursor()
            cur.execute(listing_sql)
            rows = cur.fetchall()

        log_histogram("get_existing_files_duration", time.time() - began_at)
        log_counter("get_existing_files_success", labels={"file_count": len(rows)})
        return rows
    except Exception as exc:
        reason = str(exc)
        log_counter("get_existing_files_error", labels={"error": reason})
        logging.error(f"Error fetching existing files: {reason}")
        raise


#
# End of RAG_QA_Chat.py
########################################################################################################################
|