File size: 5,910 Bytes
1be405f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
# RAG_QA_Chat.py
# Description: Functions supporting the RAG QA Chat functionality
#
# Imports
#
#
# External Imports
import json
import logging
import os
import tempfile
import time
from typing import List, Tuple, IO, Union
#
# Local Imports
from App_Function_Libraries.DB.DB_Manager import db, search_db, DatabaseError, get_media_content
from App_Function_Libraries.RAG.RAG_Library_2 import generate_answer
from App_Function_Libraries.Metrics.metrics_logger import log_counter, log_histogram
#
########################################################################################################################
#
# Functions:

def rag_qa_chat(message: str, history: List[Tuple[str, str]], context: Union[str, IO[str]], api_choice: str) -> Tuple[List[Tuple[str, str]], str]:
    """Answer *message* using *context* and the prior chat, appending the new turn to *history*.

    Args:
        message: The user's new question.
        history: Prior turns as ``(human, ai)`` pairs. This list is mutated in
            place: on success ``(message, response)`` is appended to it.
        context: Source text for the answer, in one of three forms:
            * a file-like object with ``read()`` - its content is used, and
              ``bytes`` content is decoded as UTF-8;
            * a string of the form ``"media_id:<int>"`` - the content is fetched
              with ``get_media_content(<int>)``;
            * anything else - converted with ``str()`` and used verbatim.
        api_choice: Backend identifier forwarded to ``generate_answer``.

    Returns:
        ``(history, error_message)``. ``error_message`` is ``""`` on success.
        This function does not raise: a ``DatabaseError`` or any other
        exception is counted, logged (with traceback) and reported through
        ``error_message`` instead.
    """
    log_counter("rag_qa_chat_attempt", labels={"api_choice": api_choice})
    start_time = time.time()
    try:
        # Prepare the context based on the selected source
        if hasattr(context, 'read'):
            # Handle uploaded file; binary-mode files yield bytes, so decode them.
            context_text = context.read()
            if isinstance(context_text, bytes):
                context_text = context_text.decode('utf-8')
            log_counter("rag_qa_chat_uploaded_file")
        elif isinstance(context, str) and context.startswith("media_id:"):
            # Handle existing file or search result; a non-integer id raises
            # ValueError, which is reported by the generic handler below.
            media_id = int(context.split(":")[1])
            context_text = get_media_content(media_id)
            log_counter("rag_qa_chat_existing_media", labels={"media_id": media_id})
        else:
            context_text = str(context)
            log_counter("rag_qa_chat_string_context")

        # Prepare the full context including chat history
        full_context = "\n".join([f"Human: {h[0]}\nAI: {h[1]}" for h in history])
        full_context += f"\n\nContext: {context_text}\n\nHuman: {message}\nAI:"

        # Generate response using the selected API
        response = generate_answer(api_choice, full_context, message)

        # Update history (in place, so callers holding the same list see the new turn)
        history.append((message, response))

        chat_duration = time.time() - start_time
        log_histogram("rag_qa_chat_duration", chat_duration, labels={"api_choice": api_choice})
        log_counter("rag_qa_chat_success", labels={"api_choice": api_choice})

        return history, ""
    except DatabaseError as e:
        log_counter("rag_qa_chat_database_error", labels={"error": str(e)})
        # logging.exception keeps the traceback that logging.error would discard.
        logging.exception(f"Database error in rag_qa_chat: {str(e)}")
        return history, f"An error occurred while accessing the database: {str(e)}"
    except Exception as e:
        log_counter("rag_qa_chat_unexpected_error", labels={"error": str(e)})
        logging.exception(f"Unexpected error in rag_qa_chat: {str(e)}")
        return history, f"An unexpected error occurred: {str(e)}"



def save_chat_history(history: List[Tuple[str, str]]) -> str:
    """Write *history* as JSON to a new temporary ``.json`` file and return its path.

    The file is created with ``delete=False`` so it survives this call; the
    caller owns the returned path and is responsible for removing it.

    Args:
        history: Chat turns as ``(human, ai)`` pairs; must be JSON-serializable.

    Returns:
        Filesystem path of the written JSON file.

    Raises:
        Exception: Any failure is counted, logged with its traceback, and
            re-raised. An example is ``TypeError`` from ``json.dump`` on
            non-serializable content. If the failure happens while the file
            is being written, the partially written temporary file is removed
            first.
    """
    log_counter("save_chat_history_attempt")
    start_time = time.time()
    temp_path = None
    written = False
    try:
        with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.json') as temp_file:
            temp_path = temp_file.name
            json.dump(history, temp_file)
        # The file is flushed and closed once the `with` exits; only then is the
        # save actually complete and safe to report as a success.
        written = True
        save_duration = time.time() - start_time
        log_histogram("save_chat_history_duration", save_duration)
        log_counter("save_chat_history_success")
        return temp_path
    except Exception as e:
        if temp_path is not None and not written:
            # delete=False means a failed write would otherwise leave an
            # orphaned, truncated .json file behind.
            try:
                os.remove(temp_path)
            except OSError:
                # Best-effort cleanup; the original error is what gets re-raised.
                pass
        log_counter("save_chat_history_error", labels={"error": str(e)})
        logging.exception(f"Error saving chat history: {str(e)}")
        raise


def load_chat_history(file: IO[str]) -> List[Tuple[str, str]]:
    """Decode a JSON chat history from an already-open *file*.

    Args:
        file: Readable file object containing the JSON document, e.g. one
            produced by ``save_chat_history``.

    Returns:
        The decoded JSON value as-is. ``json`` produces lists, so each turn
        comes back as a ``[human, ai]`` list rather than a tuple. That is
        presumably fine for callers that only index the pairs; confirm before
        relying on tuple semantics.

    Raises:
        Exception: Any read or decode error is counted, logged and re-raised.
    """
    log_counter("load_chat_history_attempt")
    began_at = time.time()
    try:
        loaded = json.load(file)
        log_histogram("load_chat_history_duration", time.time() - began_at)
        log_counter("load_chat_history_success")
        return loaded
    except Exception as exc:
        reason = str(exc)
        log_counter("load_chat_history_error", labels={"error": reason})
        logging.error(f"Error loading chat history: {reason}")
        raise

def search_database(query: str) -> List[Tuple[int, str]]:
    """Search the media database for *query* and return ``(id, title)`` pairs.

    ``search_db`` is called with the fields ``["title", "content"]``, an empty
    third argument, ``page=1`` and ``results_per_page=10``. The meaning of the
    empty third argument is defined by ``search_db``; verify there.

    Args:
        query: Free-text search string passed straight to ``search_db``.

    Returns:
        One ``(result['id'], result['title'])`` tuple per row returned by
        ``search_db``, in the order it returned them.

    Raises:
        Exception: Any error (including from metrics calls) is counted,
            logged and re-raised.
    """
    try:
        log_counter("search_database_attempt")
        t0 = time.time()
        matches = search_db(query, ["title", "content"], "", page=1, results_per_page=10)
        log_histogram("search_database_duration", time.time() - t0)
        log_counter("search_database_success", labels={"result_count": len(matches)})
        id_title_pairs = [(row['id'], row['title']) for row in matches]
        return id_title_pairs
    except Exception as err:
        reason = str(err)
        log_counter("search_database_error", labels={"error": reason})
        logging.error(f"Error searching database: {reason}")
        raise


def get_existing_files() -> List[Tuple[int, str]]:
    """List every media item as ``(id, title)``, sorted by title.

    Runs ``SELECT id, title FROM Media ORDER BY title`` on a connection from
    ``db.get_connection()`` and returns the cursor's ``fetchall()`` result
    unchanged.

    Returns:
        The fetched rows. Their concrete row type is whatever the DB driver's
        cursor yields; confirm if tuple behavior is required.

    Raises:
        Exception: Any database or metrics error is counted, logged and
            re-raised.
    """
    log_counter("get_existing_files_attempt")
    t_begin = time.time()
    try:
        with db.get_connection() as connection:
            cur = connection.cursor()
            cur.execute("SELECT id, title FROM Media ORDER BY title")
            rows = cur.fetchall()
        log_histogram("get_existing_files_duration", time.time() - t_begin)
        log_counter("get_existing_files_success", labels={"file_count": len(rows)})
        return rows
    except Exception as err:
        reason = str(err)
        log_counter("get_existing_files_error", labels={"error": reason})
        logging.error(f"Error fetching existing files: {reason}")
        raise

#
# End of RAG_QA_Chat.py
########################################################################################################################