diff --git a/App_Function_Libraries/Audio/Audio_Files.py b/App_Function_Libraries/Audio/Audio_Files.py index 2780806e27e59cdba34be9bd988544e3f2bdb5c7..3c216813539bd44526a4d85cfeda8ad8f7f09926 100644 --- a/App_Function_Libraries/Audio/Audio_Files.py +++ b/App_Function_Libraries/Audio/Audio_Files.py @@ -117,16 +117,15 @@ def process_audio_files(audio_urls, audio_file, whisper_model, api_name, api_key progress = [] all_transcriptions = [] all_summaries = [] - #v2 + temp_files = [] # Keep track of temporary files + def format_transcription_with_timestamps(segments): if keep_timestamps: formatted_segments = [] for segment in segments: start = segment.get('Time_Start', 0) end = segment.get('Time_End', 0) - text = segment.get('Text', '').strip() # Ensure text is stripped of leading/trailing spaces - - # Add the formatted timestamp and text to the list, followed by a newline + text = segment.get('Text', '').strip() formatted_segments.append(f"[{start:.2f}-{end:.2f}] {text}") # Join the segments with a newline to ensure proper formatting @@ -191,205 +190,64 @@ def process_audio_files(audio_urls, audio_file, whisper_model, api_name, api_key 'language': chunk_language } - # Process multiple URLs - urls = [url.strip() for url in audio_urls.split('\n') if url.strip()] - - for i, url in enumerate(urls): - update_progress(f"Processing URL {i + 1}/{len(urls)}: {url}") - - # Download and process audio file - audio_file_path = download_audio_file(url, use_cookies, cookies) - if not os.path.exists(audio_file_path): - update_progress(f"Downloaded file not found: {audio_file_path}") - failed_count += 1 - log_counter( - metric_name="audio_files_failed_total", - labels={"whisper_model": whisper_model, "api_name": api_name}, - value=1 - ) - continue - - temp_files.append(audio_file_path) - update_progress("Audio file downloaded successfully.") - - # Re-encode MP3 to fix potential issues - reencoded_mp3_path = reencode_mp3(audio_file_path) - if not os.path.exists(reencoded_mp3_path): - 
update_progress(f"Re-encoded file not found: {reencoded_mp3_path}") - failed_count += 1 - log_counter( - metric_name="audio_files_failed_total", - labels={"whisper_model": whisper_model, "api_name": api_name}, - value=1 - ) - continue - - temp_files.append(reencoded_mp3_path) - - # Convert re-encoded MP3 to WAV - wav_file_path = convert_mp3_to_wav(reencoded_mp3_path) - if not os.path.exists(wav_file_path): - update_progress(f"Converted WAV file not found: {wav_file_path}") - failed_count += 1 - log_counter( - metric_name="audio_files_failed_total", - labels={"whisper_model": whisper_model, "api_name": api_name}, - value=1 - ) - continue - - temp_files.append(wav_file_path) - - # Initialize transcription - transcription = "" - - # Transcribe audio - if diarize: - segments = speech_to_text(wav_file_path, whisper_model=whisper_model, diarize=True) - else: - segments = speech_to_text(wav_file_path, whisper_model=whisper_model) - - # Handle segments nested under 'segments' key - if isinstance(segments, dict) and 'segments' in segments: - segments = segments['segments'] - - if isinstance(segments, list): - # Log first 5 segments for debugging - logging.debug(f"Segments before formatting: {segments[:5]}") - transcription = format_transcription_with_timestamps(segments) - logging.debug(f"Formatted transcription (first 500 chars): {transcription[:500]}") - update_progress("Audio transcribed successfully.") - else: - update_progress("Unexpected segments format received from speech_to_text.") - logging.error(f"Unexpected segments format: {segments}") - failed_count += 1 - log_counter( - metric_name="audio_files_failed_total", - labels={"whisper_model": whisper_model, "api_name": api_name}, - value=1 - ) - continue - - if not transcription.strip(): - update_progress("Transcription is empty.") - failed_count += 1 - log_counter( - metric_name="audio_files_failed_total", - labels={"whisper_model": whisper_model, "api_name": api_name}, - value=1 - ) - else: - # Apply chunking - 
chunked_text = improved_chunking_process(transcription, chunk_options) - - # Summarize - logging.debug(f"Audio Transcription API Name: {api_name}") - if api_name: - try: - summary = perform_summarization(api_name, chunked_text, custom_prompt_input, api_key) - update_progress("Audio summarized successfully.") - except Exception as e: - logging.error(f"Error during summarization: {str(e)}") - summary = "Summary generation failed" - failed_count += 1 - log_counter( - metric_name="audio_files_failed_total", - labels={"whisper_model": whisper_model, "api_name": api_name}, - value=1 - ) - else: - summary = "No summary available (API not provided)" + # Process URLs if provided + if audio_urls: + urls = [url.strip() for url in audio_urls.split('\n') if url.strip()] + for i, url in enumerate(urls): + try: + update_progress(f"Processing URL {i + 1}/{len(urls)}: {url}") - all_transcriptions.append(transcription) - all_summaries.append(summary) + # Download and process audio file + audio_file_path = download_audio_file(url, use_cookies, cookies) + if not audio_file_path: + raise FileNotFoundError(f"Failed to download audio from URL: {url}") - # Use custom_title if provided, otherwise use the original filename - title = custom_title if custom_title else os.path.basename(wav_file_path) - - # Add to database - add_media_with_keywords( - url=url, - title=title, - media_type='audio', - content=transcription, - keywords=custom_keywords, - prompt=custom_prompt_input, - summary=summary, - transcription_model=whisper_model, - author="Unknown", - ingestion_date=datetime.now().strftime('%Y-%m-%d') - ) - update_progress("Audio file processed and added to database.") - processed_count += 1 - log_counter( - metric_name="audio_files_processed_total", - labels={"whisper_model": whisper_model, "api_name": api_name}, - value=1 - ) - - # Process uploaded file if provided - if audio_file: - url = generate_unique_id() - if os.path.getsize(audio_file.name) > MAX_FILE_SIZE: - update_progress( - 
f"Uploaded file size exceeds the maximum limit of {MAX_FILE_SIZE / (1024 * 1024):.2f}MB. Skipping this file.") - else: - try: - # Re-encode MP3 to fix potential issues - reencoded_mp3_path = reencode_mp3(audio_file.name) - if not os.path.exists(reencoded_mp3_path): - update_progress(f"Re-encoded file not found: {reencoded_mp3_path}") - return update_progress("Processing failed: Re-encoded file not found"), "", "" + temp_files.append(audio_file_path) + # Process the audio file + reencoded_mp3_path = reencode_mp3(audio_file_path) temp_files.append(reencoded_mp3_path) - # Convert re-encoded MP3 to WAV wav_file_path = convert_mp3_to_wav(reencoded_mp3_path) - if not os.path.exists(wav_file_path): - update_progress(f"Converted WAV file not found: {wav_file_path}") - return update_progress("Processing failed: Converted WAV file not found"), "", "" - temp_files.append(wav_file_path) - # Initialize transcription - transcription = "" - - if diarize: - segments = speech_to_text(wav_file_path, whisper_model=whisper_model, diarize=True) - else: - segments = speech_to_text(wav_file_path, whisper_model=whisper_model) + # Transcribe audio + segments = speech_to_text(wav_file_path, whisper_model=whisper_model, diarize=diarize) - # Handle segments nested under 'segments' key + # Handle segments format if isinstance(segments, dict) and 'segments' in segments: segments = segments['segments'] - if isinstance(segments, list): - transcription = format_transcription_with_timestamps(segments) - else: - update_progress("Unexpected segments format received from speech_to_text.") - logging.error(f"Unexpected segments format: {segments}") + if not isinstance(segments, list): + raise ValueError("Unexpected segments format received from speech_to_text") - chunked_text = improved_chunking_process(transcription, chunk_options) + transcription = format_transcription_with_timestamps(segments) + if not transcription.strip(): + raise ValueError("Empty transcription generated") - logging.debug(f"Audio 
Transcription API Name: {api_name}") - if api_name: + # Initialize summary with default value + summary = "No summary available" + + # Attempt summarization if API is provided + if api_name and api_name.lower() != "none": try: - summary = perform_summarization(api_name, chunked_text, custom_prompt_input, api_key) + chunked_text = improved_chunking_process(transcription, chunk_options) + summary_result = perform_summarization(api_name, chunked_text, custom_prompt_input, api_key) + if summary_result: + summary = summary_result update_progress("Audio summarized successfully.") except Exception as e: - logging.error(f"Error during summarization: {str(e)}") + logging.error(f"Summarization failed: {str(e)}") summary = "Summary generation failed" - else: - summary = "No summary available (API not provided)" + # Add to results all_transcriptions.append(transcription) all_summaries.append(summary) - # Use custom_title if provided, otherwise use the original filename + # Add to database title = custom_title if custom_title else os.path.basename(wav_file_path) - add_media_with_keywords( - url="Uploaded File", + url=url, title=title, media_type='audio', content=transcription, @@ -400,65 +258,112 @@ def process_audio_files(audio_urls, audio_file, whisper_model, api_name, api_key author="Unknown", ingestion_date=datetime.now().strftime('%Y-%m-%d') ) - update_progress("Uploaded file processed and added to database.") + processed_count += 1 - log_counter( - metric_name="audio_files_processed_total", - labels={"whisper_model": whisper_model, "api_name": api_name}, - value=1 - ) + update_progress(f"Successfully processed URL {i + 1}") + log_counter("audio_files_processed_total", 1, {"whisper_model": whisper_model, "api_name": api_name}) + except Exception as e: - update_progress(f"Error processing uploaded file: {str(e)}") - logging.error(f"Error processing uploaded file: {str(e)}") failed_count += 1 - log_counter( - metric_name="audio_files_failed_total", - labels={"whisper_model": 
whisper_model, "api_name": api_name}, - value=1 - ) - return update_progress("Processing failed: Error processing uploaded file"), "", "" - # Final cleanup - if not keep_original: - cleanup_files() + update_progress(f"Failed to process URL {i + 1}: {str(e)}") + log_counter("audio_files_failed_total", 1, {"whisper_model": whisper_model, "api_name": api_name}) + continue - end_time = time.time() - processing_time = end_time - start_time - # Log processing time - log_histogram( - metric_name="audio_processing_time_seconds", - value=processing_time, - labels={"whisper_model": whisper_model, "api_name": api_name} - ) + # Process uploaded file if provided + if audio_file: + try: + update_progress("Processing uploaded file...") + if os.path.getsize(audio_file.name) > MAX_FILE_SIZE: + raise ValueError(f"File size exceeds maximum limit of {MAX_FILE_SIZE / (1024 * 1024):.2f}MB") - # Optionally, log total counts - log_counter( - metric_name="total_audio_files_processed", - labels={"whisper_model": whisper_model, "api_name": api_name}, - value=processed_count - ) + reencoded_mp3_path = reencode_mp3(audio_file.name) + temp_files.append(reencoded_mp3_path) - log_counter( - metric_name="total_audio_files_failed", - labels={"whisper_model": whisper_model, "api_name": api_name}, - value=failed_count - ) + wav_file_path = convert_mp3_to_wav(reencoded_mp3_path) + temp_files.append(wav_file_path) + + # Transcribe audio + segments = speech_to_text(wav_file_path, whisper_model=whisper_model, diarize=diarize) + + if isinstance(segments, dict) and 'segments' in segments: + segments = segments['segments'] + + if not isinstance(segments, list): + raise ValueError("Unexpected segments format received from speech_to_text") + transcription = format_transcription_with_timestamps(segments) + if not transcription.strip(): + raise ValueError("Empty transcription generated") + + # Initialize summary with default value + summary = "No summary available" - final_progress = update_progress("All 
processing complete.") - final_transcriptions = "\n\n".join(all_transcriptions) - final_summaries = "\n\n".join(all_summaries) + # Attempt summarization if API is provided + if api_name and api_name.lower() != "none": + try: + chunked_text = improved_chunking_process(transcription, chunk_options) + summary_result = perform_summarization(api_name, chunked_text, custom_prompt_input, api_key) + if summary_result: + summary = summary_result + update_progress("Audio summarized successfully.") + except Exception as e: + logging.error(f"Summarization failed: {str(e)}") + summary = "Summary generation failed" + + # Add to results + all_transcriptions.append(transcription) + all_summaries.append(summary) + + # Add to database + title = custom_title if custom_title else os.path.basename(wav_file_path) + add_media_with_keywords( + url="Uploaded File", + title=title, + media_type='audio', + content=transcription, + keywords=custom_keywords, + prompt=custom_prompt_input, + summary=summary, + transcription_model=whisper_model, + author="Unknown", + ingestion_date=datetime.now().strftime('%Y-%m-%d') + ) + + processed_count += 1 + update_progress("Successfully processed uploaded file") + log_counter("audio_files_processed_total", 1, {"whisper_model": whisper_model, "api_name": api_name}) + + except Exception as e: + failed_count += 1 + update_progress(f"Failed to process uploaded file: {str(e)}") + log_counter("audio_files_failed_total", 1, {"whisper_model": whisper_model, "api_name": api_name}) + + # Cleanup temporary files + if not keep_original: + cleanup_files() + + # Log processing metrics + processing_time = time.time() - start_time + log_histogram("audio_processing_time_seconds", processing_time, + {"whisper_model": whisper_model, "api_name": api_name}) + log_counter("total_audio_files_processed", processed_count, + {"whisper_model": whisper_model, "api_name": api_name}) + log_counter("total_audio_files_failed", failed_count, + {"whisper_model": whisper_model, "api_name": 
api_name}) + + # Prepare final output + final_progress = update_progress(f"Processing complete. Processed: {processed_count}, Failed: {failed_count}") + final_transcriptions = "\n\n".join(all_transcriptions) if all_transcriptions else "No transcriptions available" + final_summaries = "\n\n".join(all_summaries) if all_summaries else "No summaries available" return final_progress, final_transcriptions, final_summaries except Exception as e: - logging.error(f"Error processing audio files: {str(e)}") - log_counter( - metric_name="audio_files_failed_total", - labels={"whisper_model": whisper_model, "api_name": api_name}, - value=1 - ) - cleanup_files() - return update_progress(f"Processing failed: {str(e)}"), "", "" + logging.error(f"Error in process_audio_files: {str(e)}") + log_counter("audio_files_failed_total", 1, {"whisper_model": whisper_model, "api_name": api_name}) + if not keep_original: + cleanup_files() + return update_progress(f"Processing failed: {str(e)}"), "No transcriptions available", "No summaries available" def format_transcription_with_timestamps(segments, keep_timestamps): diff --git a/App_Function_Libraries/Audio/Audio_Transcription_Lib.py b/App_Function_Libraries/Audio/Audio_Transcription_Lib.py index 1f8053cbe70eed21a41460dfde8a1ae0b237d612..0542832c10cce5b3a9f478db65bf6692734e11e1 100644 --- a/App_Function_Libraries/Audio/Audio_Transcription_Lib.py +++ b/App_Function_Libraries/Audio/Audio_Transcription_Lib.py @@ -332,4 +332,4 @@ def save_audio_temp(audio_data, sample_rate=16000): # # -####################################################################################################################### \ No newline at end of file +####################################################################################################################### diff --git a/App_Function_Libraries/Benchmarks_Evaluations/InfiniteBench/InifiniteBench/__pycache__/test_chat_API_Calls.cpython-312-pytest-7.2.1.pyc 
b/App_Function_Libraries/Benchmarks_Evaluations/InfiniteBench/InifiniteBench/__pycache__/test_chat_API_Calls.cpython-312-pytest-7.2.1.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0c387f83b40c94c758e8e5cdf609e6d147016455 Binary files /dev/null and b/App_Function_Libraries/Benchmarks_Evaluations/InfiniteBench/InifiniteBench/__pycache__/test_chat_API_Calls.cpython-312-pytest-7.2.1.pyc differ diff --git a/App_Function_Libraries/Benchmarks_Evaluations/ms_g_eval.py b/App_Function_Libraries/Benchmarks_Evaluations/ms_g_eval.py index 9cbcfacd1e0ec6005fe7aa231cadb28671ab6cad..e8e1c1cb4cf9dee73f9c6d3c4a62674e4b944d4d 100644 --- a/App_Function_Libraries/Benchmarks_Evaluations/ms_g_eval.py +++ b/App_Function_Libraries/Benchmarks_Evaluations/ms_g_eval.py @@ -24,7 +24,7 @@ from tenacity import ( wait_random_exponential, ) -from App_Function_Libraries.Chat import chat_api_call +from App_Function_Libraries.Chat.Chat_Functions import chat_api_call # ####################################################################################################################### diff --git a/App_Function_Libraries/Books/Book_Ingestion_Lib.py b/App_Function_Libraries/Books/Book_Ingestion_Lib.py index 66e49d904c65c2839a31d06edf5054d96c8c7fb6..8e55ef6fe2d95cfccf6ea4d8f1ce43c226b7f668 100644 --- a/App_Function_Libraries/Books/Book_Ingestion_Lib.py +++ b/App_Function_Libraries/Books/Book_Ingestion_Lib.py @@ -18,6 +18,9 @@ import tempfile import zipfile from datetime import datetime import logging +import xml.etree.ElementTree as ET +import html2text +import csv # # External Imports import ebooklib @@ -241,109 +244,244 @@ def process_zip_file(zip_file, return "\n".join(results) -def import_file_handler(file, - title, - author, - keywords, - system_prompt, - custom_prompt, - auto_summarize, - api_name, - api_key, - max_chunk_size, - chunk_overlap, - custom_chapter_pattern - ): +def import_html(file_path, title=None, author=None, keywords=None, **kwargs): + """ + Imports an 
HTML file and converts it to markdown format. + """ try: - log_counter("file_import_attempt", labels={"file_name": file.name}) - - # Handle max_chunk_size - if isinstance(max_chunk_size, str): - max_chunk_size = int(max_chunk_size) if max_chunk_size.strip() else 4000 - elif not isinstance(max_chunk_size, int): - max_chunk_size = 4000 # Default value if not a string or int - - # Handle chunk_overlap - if isinstance(chunk_overlap, str): - chunk_overlap = int(chunk_overlap) if chunk_overlap.strip() else 0 - elif not isinstance(chunk_overlap, int): - chunk_overlap = 0 # Default value if not a string or int - - chunk_options = { - 'method': 'chapter', - 'max_size': max_chunk_size, - 'overlap': chunk_overlap, - 'custom_chapter_pattern': custom_chapter_pattern if custom_chapter_pattern else None - } + logging.info(f"Importing HTML file from {file_path}") + h = html2text.HTML2Text() + h.ignore_links = False - if file is None: - log_counter("file_import_error", labels={"error": "No file uploaded"}) - return "No file uploaded." + with open(file_path, 'r', encoding='utf-8') as file: + html_content = file.read() - file_path = file.name - if not os.path.exists(file_path): - log_counter("file_import_error", labels={"error": "File not found", "file_name": file.name}) - return "Uploaded file not found." 
+ markdown_content = h.handle(html_content) - start_time = datetime.now() + # Extract title from HTML if not provided + if not title: + soup = BeautifulSoup(html_content, 'html.parser') + title_tag = soup.find('title') + title = title_tag.string if title_tag else os.path.basename(file_path) - if file_path.lower().endswith('.epub'): - status = import_epub( - file_path, - title, - author, - keywords, - custom_prompt=custom_prompt, - system_prompt=system_prompt, - summary=None, - auto_summarize=auto_summarize, - api_name=api_name, - api_key=api_key, - chunk_options=chunk_options, - custom_chapter_pattern=custom_chapter_pattern - ) - log_counter("epub_import_success", labels={"file_name": file.name}) - result = f"📚 EPUB Imported Successfully:\n{status}" - elif file.name.lower().endswith('.zip'): - status = process_zip_file( - zip_file=file, - title=title, - author=author, - keywords=keywords, - custom_prompt=custom_prompt, - system_prompt=system_prompt, - summary=None, - auto_summarize=auto_summarize, - api_name=api_name, - api_key=api_key, - chunk_options=chunk_options - ) - log_counter("zip_import_success", labels={"file_name": file.name}) - result = f"📦 ZIP Processed Successfully:\n{status}" - elif file.name.lower().endswith(('.chm', '.html', '.pdf', '.xml', '.opml')): - file_type = file.name.split('.')[-1].upper() - log_counter("unsupported_file_type", labels={"file_type": file_type}) - result = f"{file_type} file import is not yet supported." - else: - log_counter("unsupported_file_type", labels={"file_type": file.name.split('.')[-1]}) - result = "❌ Unsupported file type. Please upload an `.epub` file or a `.zip` file containing `.epub` files." 
+ return process_markdown_content(markdown_content, file_path, title, author, keywords, **kwargs) - end_time = datetime.now() - processing_time = (end_time - start_time).total_seconds() - log_histogram("file_import_duration", processing_time, labels={"file_name": file.name}) + except Exception as e: + logging.exception(f"Error importing HTML file: {str(e)}") + raise + + +def import_xml(file_path, title=None, author=None, keywords=None, **kwargs): + """ + Imports an XML file and converts it to markdown format. + """ + try: + logging.info(f"Importing XML file from {file_path}") + tree = ET.parse(file_path) + root = tree.getroot() + + # Convert XML to markdown + markdown_content = xml_to_markdown(root) + + return process_markdown_content(markdown_content, file_path, title, author, keywords, **kwargs) + + except Exception as e: + logging.exception(f"Error importing XML file: {str(e)}") + raise + + +def import_opml(file_path, title=None, author=None, keywords=None, **kwargs): + """ + Imports an OPML file and converts it to markdown format. + """ + try: + logging.info(f"Importing OPML file from {file_path}") + tree = ET.parse(file_path) + root = tree.getroot() + + # Extract title from OPML if not provided + if not title: + title_elem = root.find(".//title") + title = title_elem.text if title_elem is not None else os.path.basename(file_path) + + # Convert OPML to markdown + markdown_content = opml_to_markdown(root) + + return process_markdown_content(markdown_content, file_path, title, author, keywords, **kwargs) + + except Exception as e: + logging.exception(f"Error importing OPML file: {str(e)}") + raise + + +def xml_to_markdown(element, level=0): + """ + Recursively converts XML elements to markdown format. 
+ """ + markdown = "" + + # Add element name as heading + if level > 0: + markdown += f"{'#' * min(level, 6)} {element.tag}\n\n" + + # Add element text if it exists + if element.text and element.text.strip(): + markdown += f"{element.text.strip()}\n\n" + + # Process child elements + for child in element: + markdown += xml_to_markdown(child, level + 1) + + return markdown + +def opml_to_markdown(root): + """ + Converts OPML structure to markdown format. + """ + markdown = "# Table of Contents\n\n" + + def process_outline(outline, level=0): + result = "" + for item in outline.findall("outline"): + text = item.get("text", "") + result += f"{' ' * level}- {text}\n" + result += process_outline(item, level + 1) return result + body = root.find(".//body") + if body is not None: + markdown += process_outline(body) + + return markdown + + +def process_markdown_content(markdown_content, file_path, title, author, keywords, **kwargs): + """ + Processes markdown content and adds it to the database. + """ + info_dict = { + 'title': title or os.path.basename(file_path), + 'uploader': author or "Unknown", + 'ingestion_date': datetime.now().strftime('%Y-%m-%d') + } + + # Create segments (you may want to adjust the chunking method) + segments = [{'Text': markdown_content}] + + # Add to database + result = add_media_to_database( + url=file_path, + info_dict=info_dict, + segments=segments, + summary=kwargs.get('summary', "No summary provided"), + keywords=keywords.split(',') if keywords else [], + custom_prompt_input=kwargs.get('custom_prompt'), + whisper_model="Imported", + media_type="document", + overwrite=False + ) + + return f"Document '{title}' imported successfully. Database result: {result}" + + +def import_file_handler(files, + author, + keywords, + system_prompt, + custom_prompt, + auto_summarize, + api_name, + api_key, + max_chunk_size, + chunk_overlap, + custom_chapter_pattern): + try: + if not files: + return "No files uploaded." 
+ + # Convert single file to list for consistent processing + if not isinstance(files, list): + files = [files] + + results = [] + for file in files: + log_counter("file_import_attempt", labels={"file_name": file.name}) + + # Handle max_chunk_size and chunk_overlap + chunk_size = int(max_chunk_size) if isinstance(max_chunk_size, (str, int)) else 4000 + overlap = int(chunk_overlap) if isinstance(chunk_overlap, (str, int)) else 0 + + chunk_options = { + 'method': 'chapter', + 'max_size': chunk_size, + 'overlap': overlap, + 'custom_chapter_pattern': custom_chapter_pattern if custom_chapter_pattern else None + } + + file_path = file.name + if not os.path.exists(file_path): + results.append(f"❌ File not found: {file.name}") + continue + + start_time = datetime.now() + + # Extract title from filename + title = os.path.splitext(os.path.basename(file_path))[0] + + if file_path.lower().endswith('.epub'): + status = import_epub( + file_path, + title=title, # Use filename as title + author=author, + keywords=keywords, + custom_prompt=custom_prompt, + system_prompt=system_prompt, + summary=None, + auto_summarize=auto_summarize, + api_name=api_name, + api_key=api_key, + chunk_options=chunk_options, + custom_chapter_pattern=custom_chapter_pattern + ) + log_counter("epub_import_success", labels={"file_name": file.name}) + results.append(f"📚 {file.name}: {status}") + + elif file_path.lower().endswith('.zip'): + status = process_zip_file( + zip_file=file, + title=None, # Let each file use its own name + author=author, + keywords=keywords, + custom_prompt=custom_prompt, + system_prompt=system_prompt, + summary=None, + auto_summarize=auto_summarize, + api_name=api_name, + api_key=api_key, + chunk_options=chunk_options + ) + log_counter("zip_import_success", labels={"file_name": file.name}) + results.append(f"📦 {file.name}: {status}") + else: + results.append(f"❌ Unsupported file type: {file.name}") + continue + + end_time = datetime.now() + processing_time = (end_time - 
start_time).total_seconds() + log_histogram("file_import_duration", processing_time, labels={"file_name": file.name}) + + return "\n\n".join(results) + except ValueError as ve: logging.exception(f"Error parsing input values: {str(ve)}") - log_counter("file_import_error", labels={"error": "Invalid input", "file_name": file.name}) return f"❌ Error: Invalid input for chunk size or overlap. Please enter valid numbers." except Exception as e: logging.exception(f"Error during file import: {str(e)}") - log_counter("file_import_error", labels={"error": str(e), "file_name": file.name}) return f"❌ Error during import: {str(e)}" + def read_epub(file_path): """ Reads and extracts text from an EPUB file. @@ -424,9 +562,9 @@ def ingest_text_file(file_path, title=None, author=None, keywords=None): # Add the text file to the database add_media_with_keywords( - url=file_path, + url="its_a_book", title=title, - media_type='document', + media_type='book', content=content, keywords=keywords, prompt='No prompt for text files', diff --git a/App_Function_Libraries/Chat/Chat_Functions.py b/App_Function_Libraries/Chat/Chat_Functions.py new file mode 100644 index 0000000000000000000000000000000000000000..1ce348668bd95e91dc21483cf092944bc7aa1b35 --- /dev/null +++ b/App_Function_Libraries/Chat/Chat_Functions.py @@ -0,0 +1,453 @@ +# Chat_Functions.py +# Chat functions for interacting with the LLMs as chatbots +import base64 +# Imports +import json +import logging +import os +import re +import sqlite3 +import tempfile +import time +from datetime import datetime +from pathlib import Path +# +# External Imports +# +# Local Imports +from App_Function_Libraries.DB.DB_Manager import start_new_conversation, delete_messages_in_conversation, save_message +from App_Function_Libraries.DB.RAG_QA_Chat_DB import get_db_connection, get_conversation_name +from App_Function_Libraries.LLM_API_Calls import chat_with_openai, chat_with_anthropic, chat_with_cohere, \ + chat_with_groq, chat_with_openrouter, 
chat_with_deepseek, chat_with_mistral, chat_with_huggingface +from App_Function_Libraries.LLM_API_Calls_Local import chat_with_aphrodite, chat_with_local_llm, chat_with_ollama, \ + chat_with_kobold, chat_with_llama, chat_with_oobabooga, chat_with_tabbyapi, chat_with_vllm, chat_with_custom_openai +from App_Function_Libraries.DB.SQLite_DB import load_media_content +from App_Function_Libraries.Utils.Utils import generate_unique_filename, load_and_log_configs +from App_Function_Libraries.Metrics.metrics_logger import log_counter, log_histogram +# +#################################################################################################### +# +# Functions: + +def approximate_token_count(history): + total_text = '' + for user_msg, bot_msg in history: + if user_msg: + total_text += user_msg + ' ' + if bot_msg: + total_text += bot_msg + ' ' + total_tokens = len(total_text.split()) + return total_tokens + +def chat_api_call(api_endpoint, api_key, input_data, prompt, temp, system_message=None): + log_counter("chat_api_call_attempt", labels={"api_endpoint": api_endpoint}) + start_time = time.time() + if not api_key: + api_key = None + model = None + try: + logging.info(f"Debug - Chat API Call - API Endpoint: {api_endpoint}") + logging.info(f"Debug - Chat API Call - API Key: {api_key}") + logging.info(f"Debug - Chat chat_api_call - API Endpoint: {api_endpoint}") + if api_endpoint.lower() == 'openai': + response = chat_with_openai(api_key, input_data, prompt, temp, system_message) + + elif api_endpoint.lower() == 'anthropic': + # Retrieve the model from config + loaded_config_data = load_and_log_configs() + model = loaded_config_data['models']['anthropic'] if loaded_config_data else None + response = chat_with_anthropic( + api_key=api_key, + input_data=input_data, + model=model, + custom_prompt_arg=prompt, + system_prompt=system_message + ) + + elif api_endpoint.lower() == "cohere": + response = chat_with_cohere( + api_key, + input_data, + model=model, + 
custom_prompt_arg=prompt, + system_prompt=system_message, + temp=temp + ) + + elif api_endpoint.lower() == "groq": + response = chat_with_groq(api_key, input_data, prompt, temp, system_message) + + elif api_endpoint.lower() == "openrouter": + response = chat_with_openrouter(api_key, input_data, prompt, temp, system_message) + + elif api_endpoint.lower() == "deepseek": + response = chat_with_deepseek(api_key, input_data, prompt, temp, system_message) + + elif api_endpoint.lower() == "mistral": + response = chat_with_mistral(api_key, input_data, prompt, temp, system_message) + + elif api_endpoint.lower() == "llama.cpp": + response = chat_with_llama(input_data, prompt, temp, None, api_key, system_message) + elif api_endpoint.lower() == "kobold": + response = chat_with_kobold(input_data, api_key, prompt, temp, system_message) + + elif api_endpoint.lower() == "ooba": + response = chat_with_oobabooga(input_data, api_key, prompt, temp, system_message) + + elif api_endpoint.lower() == "tabbyapi": + response = chat_with_tabbyapi(input_data, prompt, temp, system_message) + + elif api_endpoint.lower() == "vllm": + response = chat_with_vllm(input_data, prompt, system_message) + + elif api_endpoint.lower() == "local-llm": + response = chat_with_local_llm(input_data, prompt, temp, system_message) + + elif api_endpoint.lower() == "huggingface": + response = chat_with_huggingface(api_key, input_data, prompt, temp) # , system_message) + + elif api_endpoint.lower() == "ollama": + response = chat_with_ollama(input_data, prompt, None, api_key, temp, system_message) + + elif api_endpoint.lower() == "aphrodite": + response = chat_with_aphrodite(input_data, prompt, temp, system_message) + + elif api_endpoint.lower() == "custom-openai-api": + response = chat_with_custom_openai(api_key, input_data, prompt, temp, system_message) + + else: + raise ValueError(f"Unsupported API endpoint: {api_endpoint}") + + call_duration = time.time() - start_time + log_histogram("chat_api_call_duration", 
def chat(message, history, media_content, selected_parts, api_endpoint, api_key, prompt, temperature,
         system_message=None):
    """Build the conversation text for one chat turn and send it to a backend.

    Args:
        message: The new user message.
        history: Iterable of ``(user_message, assistant_response)`` pairs
            from earlier turns.
        media_content: Mapping of part name (e.g. ``"content"``,
            ``"summary"``) to text; ``None`` is treated as empty.
        selected_parts: Part name or list of part names to prepend to the
            conversation.
        api_endpoint: Backend name forwarded to ``chat_api_call``.
        api_key: Backend credential; may be ``None`` for local backends.
        prompt: Custom prompt forwarded to the backend.
        temperature: Sampling temperature; ``None`` or ``""`` selects the
            default of 0.7. An explicit 0 is honoured.
        system_message: Optional system prompt.

    Returns:
        The backend response, or ``"An error occurred: <details>"`` if
        anything fails.
    """
    log_counter("chat_attempt", labels={"api_endpoint": api_endpoint})
    start_time = time.time()
    try:
        logging.info(f"Debug - Chat Function - Message: {message}")
        logging.info(f"Debug - Chat Function - Media Content: {media_content}")
        logging.info(f"Debug - Chat Function - Selected Parts: {selected_parts}")
        logging.info(f"Debug - Chat Function - API Endpoint: {api_endpoint}")

        # Ensure selected_parts is a list
        if not isinstance(selected_parts, (list, tuple)):
            selected_parts = [selected_parts] if selected_parts else []

        # Combine the selected parts of the media content
        available_content = media_content or {}
        combined_content = "\n\n".join(
            f"{part.capitalize()}: {available_content.get(part, '')}"
            for part in selected_parts
            if part in available_content
        )

        # Prepare the input for the API: media text, earlier turns, new message.
        segments = [f"{combined_content}\n\n"] if combined_content else []
        segments.extend(
            f"{old_message}\nAssistant: {old_response}\n\n"
            for old_message, old_response in history
        )
        segments.append(f"{message}\n")
        input_data = "".join(segments)

        if system_message:
            logging.debug(f"Debug - Chat Function - System Message: {system_message}")

        # A plain truthiness test would replace a deliberate temperature of
        # 0 with the default, so only "not provided" falls back to 0.7.
        if temperature is None or temperature == "":
            temperature = 0.7
        else:
            temperature = float(temperature)
        temp = temperature

        logging.debug(f"Debug - Chat Function - Temperature: {temperature}")
        # api_key is None for local backends, so it must not be sliced here;
        # its contents are also kept out of the log.
        logging.debug(f"Debug - Chat Function - API Key provided: {bool(api_key)}")
        logging.debug(f"Debug - Chat Function - Prompt: {prompt}")

        # Use the existing API request code based on the selected endpoint
        response = chat_api_call(api_endpoint, api_key, input_data, prompt, temp, system_message)

        chat_duration = time.time() - start_time
        log_histogram("chat_duration", chat_duration, labels={"api_endpoint": api_endpoint})
        log_counter("chat_success", labels={"api_endpoint": api_endpoint})
        return response
    except Exception as e:
        log_counter("chat_error", labels={"api_endpoint": api_endpoint, "error": str(e)})
        logging.error(f"Error in chat function: {str(e)}")
        return f"An error occurred: {str(e)}"
def save_chat_history_to_db_wrapper(chatbot, conversation_id, media_content, media_name=None):
    """Persist a chat transcript, creating the conversation if needed.

    For a new conversation (falsy ``conversation_id``) the media id and name
    are read from ``media_content['content']`` (a dict or a JSON string) and
    a conversation titled ``"<media_name>_<timestamp>"`` is started. For new
    and existing conversations alike, the stored messages are deleted and
    rewritten from ``chatbot``.

    Args:
        chatbot: Iterable of ``(user_msg, assistant_msg)`` pairs.
        conversation_id: Existing conversation id, or a falsy value to
            create a new conversation.
        media_content: Media record; expected to be a dict with a
            ``'content'`` entry, anything else falls back to
            ``"unknown_media"``.
        media_name: Optional display name overriding the media title.

    Returns:
        tuple: ``(conversation_id, status_message)``. Failures are reported
        through the status message rather than raised.
    """
    log_counter("save_chat_history_to_db_attempt")
    start_time = time.time()
    logging.info(f"Attempting to save chat history. Media content type: {type(media_content)}")

    try:
        # First check if we can access the database
        try:
            with get_db_connection() as conn:
                cursor = conn.cursor()
                cursor.execute("SELECT 1")
        except sqlite3.DatabaseError as db_error:
            log_counter("save_chat_history_to_db_error", labels={"error": str(db_error)})
            logging.error(f"Database is corrupted or inaccessible: {str(db_error)}")
            return conversation_id, "Database error: The database file appears to be corrupted. Please contact support."

        # Now attempt the save
        if not conversation_id:
            # Only for new conversations, not updates
            media_id = None
            if isinstance(media_content, dict) and 'content' in media_content:
                try:
                    content = media_content['content']
                    content_json = content if isinstance(content, dict) else json.loads(content)
                    media_id = content_json.get('webpage_url')
                    media_name = media_name or content_json.get('title', 'Unnamed Media')
                except (json.JSONDecodeError, TypeError, AttributeError) as e:
                    # TypeError: content was neither a dict nor a JSON string
                    # (e.g. None). AttributeError: the JSON was not an object.
                    # Either way the save proceeds with placeholder identity.
                    logging.error(f"Error processing media content: {str(e)}")
                    media_id = "unknown_media"
                    media_name = media_name or "Unnamed Media"
            else:
                media_id = "unknown_media"
                media_name = media_name or "Unnamed Media"

            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            conversation_title = f"{media_name}_{timestamp}"
            conversation_id = start_new_conversation(title=conversation_title, media_id=media_id)
            logging.info(f"Created new conversation with ID: {conversation_id}")

        # For both new and existing conversations
        try:
            delete_messages_in_conversation(conversation_id)
            for user_msg, assistant_msg in chatbot:
                if user_msg:
                    save_message(conversation_id, "user", user_msg)
                if assistant_msg:
                    save_message(conversation_id, "assistant", assistant_msg)
        except sqlite3.DatabaseError as db_error:
            log_counter("save_chat_history_to_db_error", labels={"error": str(db_error)})
            logging.error(f"Database error during message save: {str(db_error)}")
            return conversation_id, "Database error: Unable to save messages. Please try again or contact support."

        save_duration = time.time() - start_time
        log_histogram("save_chat_history_to_db_duration", save_duration)
        log_counter("save_chat_history_to_db_success")

        return conversation_id, "Chat history saved successfully!"

    except Exception as e:
        log_counter("save_chat_history_to_db_error", labels={"error": str(e)})
        error_message = f"Failed to save chat history: {str(e)}"
        logging.error(error_message, exc_info=True)
        return conversation_id, error_message
def save_chat_history(history, conversation_id, media_content):
    """Export a chat history to a uniquely named JSON file.

    The JSON is written to a temporary file and then renamed, within the
    same directory, to ``"<conversation_name>_<timestamp>.json"`` (made
    unique by ``generate_unique_filename``).

    Args:
        history: Chat history passed to ``generate_chat_history_content``.
        conversation_id: Conversation identifier stored in the export.
        media_content: Media record used to derive a fallback name.

    Returns:
        str | None: Path of the written file, or ``None`` on failure.
    """
    log_counter("save_chat_history_attempt")
    start_time = time.time()
    temp_file_path = None
    try:
        content, conversation_name = generate_chat_history_content(history, conversation_id, media_content)

        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        safe_conversation_name = re.sub(r'[^a-zA-Z0-9_-]', '_', conversation_name)
        base_filename = f"{safe_conversation_name}_{timestamp}.json"

        # Create a temporary file
        with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.json', encoding='utf-8') as temp_file:
            # Remember the path before writing so a failed write is cleaned up too.
            temp_file_path = temp_file.name
            temp_file.write(content)

        # Generate a unique filename
        target_dir = os.path.dirname(temp_file_path)
        unique_filename = generate_unique_filename(target_dir, base_filename)
        final_path = os.path.join(target_dir, unique_filename)

        # Rename the temporary file to the unique filename
        os.rename(temp_file_path, final_path)
        temp_file_path = None  # the file now lives at final_path

        save_duration = time.time() - start_time
        log_histogram("save_chat_history_duration", save_duration)
        log_counter("save_chat_history_success")
        return final_path
    except Exception as e:
        # delete=False means nothing removes the temp file for us; do not
        # leave it behind when a later step fails.
        if temp_file_path and os.path.exists(temp_file_path):
            try:
                os.remove(temp_file_path)
            except OSError as cleanup_error:
                logging.warning(f"Could not remove temporary file {temp_file_path}: {cleanup_error}")
        log_counter("save_chat_history_error", labels={"error": str(e)})
        logging.error(f"Error saving chat history: {str(e)}")
        return None


def generate_chat_history_content(history, conversation_id, media_content):
    """Serialise a chat history to JSON and determine its conversation name.

    The name is looked up by ``conversation_id``; if none is found it falls
    back to ``"<media name>-chat"`` and finally ``"chat-<timestamp>"``.

    Args:
        history: Either a sequence of two-element ``(user, bot)`` pairs or a
            flat sequence of messages that alternate user/bot.
        conversation_id: Conversation identifier stored in the export.
        media_content: Media record passed to ``extract_media_name``.

    Returns:
        tuple: ``(json_text, conversation_name)``.
    """
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

    conversation_name = get_conversation_name(conversation_id)

    if not conversation_name:
        media_name = extract_media_name(media_content)
        if media_name:
            conversation_name = f"{media_name}-chat"
        else:
            conversation_name = f"chat-{timestamp}"  # Fallback name

    history_entries = []
    for index, entry in enumerate(history):
        if isinstance(entry, (list, tuple)) and len(entry) == 2:
            # One exchange: emit both sides, so the bot reply is kept and
            # roles do not depend on the exchange's position.
            user_msg, bot_msg = entry
            if user_msg is not None:
                history_entries.append({"role": "user", "content": user_msg})
            if bot_msg is not None:
                history_entries.append({"role": "bot", "content": bot_msg})
        else:
            # Flat message list: roles alternate, starting with the user.
            history_entries.append({
                "role": "user" if index % 2 == 0 else "bot",
                "content": entry
            })

    chat_data = {
        "conversation_id": conversation_id,
        "conversation_name": conversation_name,
        "timestamp": timestamp,
        "history": history_entries
    }

    return json.dumps(chat_data, indent=2), conversation_name
def extract_media_name(media_content):
    """Return the media title (or name) stored in a media record.

    Args:
        media_content: Expected to be a dict whose ``'content'`` entry is a
            dict or a JSON string encoding one.

    Returns:
        str | None: ``content['title']`` if truthy, else ``content['name']``;
        ``None`` when the record has an unexpected shape, the JSON cannot be
        parsed, or neither field is set.
    """
    if not isinstance(media_content, dict):
        logging.warning(f"Unexpected media_content format: {type(media_content)}")
        return None

    content = media_content.get('content', {})
    if isinstance(content, str):
        try:
            content = json.loads(content)
        except json.JSONDecodeError:
            logging.warning("Failed to parse media_content JSON string")
            return None

    # Try to extract title from the content
    if isinstance(content, dict):
        return content.get('title') or content.get('name')

    # Report the type of the value that was actually unusable.
    logging.warning(f"Unexpected media_content format: 'content' is {type(content)}")
    return None


def update_chat_content(selected_item, use_content, use_summary, use_prompt, item_mapping):
    """Load the media record for a selected item and list the parts to use.

    Args:
        selected_item: Key into ``item_mapping`` chosen by the user.
        use_content: Include the ``"content"`` part if present.
        use_summary: Include the ``"summary"`` part if present.
        use_prompt: Include the ``"prompt"`` part if present.
        item_mapping: Mapping from item label to media id.

    Returns:
        tuple: ``(content, selected_parts)`` where ``content`` is whatever
        ``load_media_content`` returned and ``selected_parts`` lists the
        requested parts that exist in it. ``({}, [])`` when nothing valid is
        selected.
    """
    log_counter("update_chat_content_attempt")
    start_time = time.time()
    logging.debug(f"Debug - Update Chat Content - Selected Item: {selected_item}\n")
    logging.debug(f"Debug - Update Chat Content - Use Content: {use_content}\n\n\n\n")
    logging.debug(f"Debug - Update Chat Content - Use Summary: {use_summary}\n\n")
    logging.debug(f"Debug - Update Chat Content - Use Prompt: {use_prompt}\n\n")
    logging.debug(f"Debug - Update Chat Content - Item Mapping: {item_mapping}\n\n")

    if not (selected_item and selected_item in item_mapping):
        log_counter("update_chat_content_error", labels={"error": "No item selected or item not in mapping"})
        logging.debug("Debug - Update Chat Content - No item selected or item not in mapping")
        return {}, []

    media_id = item_mapping[selected_item]
    content = load_media_content(media_id)

    selected_parts = []
    if isinstance(content, dict):
        # Only a mapping can be probed for parts; `in` on None raises and on
        # a string it would be a substring test.
        requested = (("content", use_content), ("summary", use_summary), ("prompt", use_prompt))
        selected_parts = [part for part, wanted in requested if wanted and part in content]

        logging.debug(f"Debug - Update Chat Content - Content keys: {list(content.keys())}")
        for key, value in content.items():
            logging.debug(f"Debug - Update Chat Content - {key} (first 500 char): {str(value)[:500]}\n\n\n\n")
    else:
        logging.debug(f"Debug - Update Chat Content - Content(first 500 char): {str(content)[:500]}\n\n\n\n")

    logging.debug(f"Debug - Update Chat Content - Selected Parts: {selected_parts}")
    update_duration = time.time() - start_time
    log_histogram("update_chat_content_duration", update_duration)
    log_counter("update_chat_content_success")
    return content, selected_parts
#######################################################################################################################
#
# Character Card Functions

CHARACTERS_FILE = Path('.', 'Helper_Scripts', 'Character_Cards', 'Characters.json')


def save_character(character_data):
    """Add or replace a character card in ``Characters.json``.

    If the card carries a base64 ``'image'``, the decoded bytes are written
    to ``<name>.png`` next to the JSON file and the card stores the absolute
    ``'image_path'`` instead of the image data.

    Note: ``character_data`` is modified in place (``'image'`` is removed
    and ``'image_path'`` added).

    Args:
        character_data: Card dict; must contain ``'name'``.

    Returns:
        None. Failures are logged and counted, not raised.
    """
    log_counter("save_character_attempt")
    start_time = time.time()
    characters_file = os.path.join(os.path.dirname(__file__), '..', 'Helper_Scripts', 'Character_Cards', 'Characters.json')
    characters_dir = os.path.dirname(characters_file)

    try:
        # A fresh checkout may not have the cards directory yet.
        os.makedirs(characters_dir, exist_ok=True)

        if os.path.exists(characters_file):
            with open(characters_file, 'r', encoding='utf-8') as f:
                characters = json.load(f)
        else:
            characters = {}

        char_name = character_data['name']

        # Save the image separately if it exists
        if 'image' in character_data:
            img_data = base64.b64decode(character_data['image'])
            # The name comes from the card, so strip path separators to keep
            # the image inside characters_dir.
            safe_stem = char_name.replace(' ', '_').replace('/', '_').replace('\\', '_')
            img_filename = f"{safe_stem}.png"
            img_path = os.path.join(characters_dir, img_filename)
            with open(img_path, 'wb') as f:
                f.write(img_data)
            character_data['image_path'] = os.path.abspath(img_path)
            del character_data['image']  # Remove the base64 image data from the JSON

        characters[char_name] = character_data

        with open(characters_file, 'w', encoding='utf-8') as f:
            json.dump(characters, f, indent=2)

        save_duration = time.time() - start_time
        log_histogram("save_character_duration", save_duration)
        log_counter("save_character_success")
        logging.info(f"Character '{char_name}' saved successfully.")
    except Exception as e:
        log_counter("save_character_error", labels={"error": str(e)})
        logging.error(f"Error saving character: {str(e)}")
def load_characters():
    """Load all character cards from ``Characters.json``.

    Returns:
        dict: Parsed contents of the file (keyed by character name when
        written by ``save_character``); an empty dict when the file is
        missing or cannot be read.
    """
    log_counter("load_characters_attempt")
    start_time = time.time()
    try:
        characters_file = os.path.join(os.path.dirname(__file__), '..', 'Helper_Scripts', 'Character_Cards', 'Characters.json')
        if not os.path.exists(characters_file):
            logging.warning(f"Characters file not found: {characters_file}")
            return {}

        with open(characters_file, 'r', encoding='utf-8') as f:
            characters = json.load(f)
        logging.debug(f"Loaded {len(characters)} characters from {characters_file}")
        load_duration = time.time() - start_time
        log_histogram("load_characters_duration", load_duration)
        log_counter("load_characters_success", labels={"character_count": len(characters)})
        return characters
    except Exception as e:
        log_counter("load_characters_error", labels={"error": str(e)})
        # Without this, a corrupt file looks the same as "no characters".
        logging.error(f"Error loading characters: {str(e)}")
        return {}



def get_character_names():
    """Return the names of all saved characters.

    Returns:
        list: Keys of the mapping returned by ``load_characters``; an empty
        list on failure.
    """
    log_counter("get_character_names_attempt")
    start_time = time.time()
    try:
        names = list(load_characters().keys())
        get_names_duration = time.time() - start_time
        log_histogram("get_character_names_duration", get_names_duration)
        log_counter("get_character_names_success", labels={"name_count": len(names)})
        return names
    except Exception as e:
        log_counter("get_character_names_error", labels={"error": str(e)})
        logging.error(f"Error getting character names: {str(e)}")
        return []
mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/App_Function_Libraries/Chunk_Lib.py b/App_Function_Libraries/Chunk_Lib.py index f60bcf2e6f450c46653f428e513a85fd4f4564dd..ef5b13ae634e6a134200c6344734a2a0d396c0c6 100644 --- a/App_Function_Libraries/Chunk_Lib.py +++ b/App_Function_Libraries/Chunk_Lib.py @@ -11,6 +11,7 @@ import json import logging import re from typing import Any, Dict, List, Optional, Tuple +import xml.etree.ElementTree as ET # # Import 3rd party from openai import OpenAI @@ -23,7 +24,6 @@ from sklearn.feature_extraction.text import TfidfVectorizer from sklearn.metrics.pairwise import cosine_similarity # # Import Local -from App_Function_Libraries.Tokenization_Methods_Lib import openai_tokenize from App_Function_Libraries.Utils.Utils import load_comprehensive_config # ####################################################################################################################### @@ -106,6 +106,7 @@ def load_document(file_path: str) -> str: def improved_chunking_process(text: str, chunk_options: Dict[str, Any] = None) -> List[Dict[str, Any]]: logging.debug("Improved chunking process started...") + logging.debug(f"Received chunk_options: {chunk_options}") # Extract JSON metadata if present json_content = {} @@ -125,49 +126,70 @@ def improved_chunking_process(text: str, chunk_options: Dict[str, Any] = None) - text = text[len(header_text):].strip() logging.debug(f"Extracted header text: {header_text}") - options = chunk_options.copy() if chunk_options else {} + # Make a copy of chunk_options and ensure values are correct types + options = {} if chunk_options: - options.update(chunk_options) - - chunk_method = options.get('method', 'words') - max_size = options.get('max_size', 2000) - overlap = options.get('overlap', 0) - language = options.get('language', None) + try: + options['method'] = str(chunk_options.get('method', 'words')) + options['max_size'] = 
int(chunk_options.get('max_size', 2000)) + options['overlap'] = int(chunk_options.get('overlap', 0)) + # Handle language specially - it can be None + lang = chunk_options.get('language') + options['language'] = str(lang) if lang is not None else None + logging.debug(f"Processed options: {options}") + except Exception as e: + logging.error(f"Error processing chunk options: {e}") + raise + else: + options = {'method': 'words', 'max_size': 2000, 'overlap': 0, 'language': None} + logging.debug("Using default options") - if language is None: - language = detect_language(text) + if options.get('language') is None: + detected_lang = detect_language(text) + options['language'] = str(detected_lang) + logging.debug(f"Detected language: {options['language']}") - if chunk_method == 'json': - chunks = chunk_text_by_json(text, max_size=max_size, overlap=overlap) - else: - chunks = chunk_text(text, chunk_method, max_size, overlap, language) + try: + if options['method'] == 'json': + chunks = chunk_text_by_json(text, max_size=options['max_size'], overlap=options['overlap']) + else: + chunks = chunk_text(text, options['method'], options['max_size'], options['overlap'], options['language']) + logging.debug(f"Created {len(chunks)} chunks using method {options['method']}") + except Exception as e: + logging.error(f"Error in chunking process: {e}") + raise chunks_with_metadata = [] total_chunks = len(chunks) - for i, chunk in enumerate(chunks): - metadata = { - 'chunk_index': i + 1, - 'total_chunks': total_chunks, - 'chunk_method': chunk_method, - 'max_size': max_size, - 'overlap': overlap, - 'language': language, - 'relative_position': (i + 1) / total_chunks - } - metadata.update(json_content) # Add the extracted JSON content to metadata - metadata['header_text'] = header_text # Add the header text to metadata - - if chunk_method == 'json': - chunk_text_content = json.dumps(chunk['json'], ensure_ascii=False) - else: - chunk_text_content = chunk + try: + for i, chunk in 
def chunk_text_by_words(text: str, max_words: int = 300, overlap: int = 0, language: str = None) -> List[str]:
    """Split text into chunks of at most ``max_words`` words.

    Consecutive chunks share ``overlap`` words. Chinese text (language code
    starting with ``zh``) is segmented with ``jieba`` and Japanese (``ja``)
    with ``fugashi``; every other language is split on whitespace.

    Args:
        text: Text to split.
        max_words: Maximum words per chunk; must be positive.
        overlap: Words repeated between consecutive chunks; must satisfy
            ``0 <= overlap < max_words``.
        language: Language code; detected from ``text`` when ``None``.

    Returns:
        List[str]: The chunks after ``post_process_chunks``.

    Raises:
        ValueError: If ``max_words``/``overlap`` cannot produce a positive
            window step.
    """
    logging.debug("chunk_text_by_words...")
    logging.debug(f"Parameters: max_words={max_words}, overlap={overlap}, language={language}")

    try:
        # The window advances by max_words - overlap. A zero step makes
        # range() fail, a negative step yields no chunks at all, and a
        # negative overlap skips words between chunks, so reject them here.
        step = max_words - overlap
        if max_words <= 0 or overlap < 0 or step <= 0:
            raise ValueError(
                f"Invalid chunk parameters: max_words={max_words}, overlap={overlap}; "
                "require max_words > 0 and 0 <= overlap < max_words"
            )

        if language is None:
            language = detect_language(text)
            logging.debug(f"Detected language: {language}")

        if language.startswith('zh'):  # Chinese
            import jieba
            words = list(jieba.cut(text))
        elif language == 'ja':  # Japanese
            import fugashi
            tagger = fugashi.Tagger()
            words = [word.surface for word in tagger(text)]
        else:  # Default to simple splitting for other languages
            words = text.split()

        logging.debug(f"Total words: {len(words)}")

        chunks = []
        for start in range(0, len(words), step):
            chunk = ' '.join(words[start:start + max_words])
            chunks.append(chunk)
            logging.debug(f"Created chunk {len(chunks)} with {len(chunk.split())} words")

        return post_process_chunks(chunks)
    except Exception as e:
        logging.error(f"Error in chunk_text_by_words: {e}")
        raise
#
# XML Chunking

def extract_xml_structure(element, path=""):
    """Recursively flatten an XML element into ``(path, text)`` tuples.

    For each element, in document order, the result holds its own stripped
    text (if any), one entry per attribute under ``"<path>/@<name>"``, the
    entries of each child, and any text following a child (its "tail"),
    recorded under the parent's path so mixed content is not lost.

    Args:
        element: An ``xml.etree.ElementTree.Element``.
        path: Slash-separated path of the parent element; empty for the root.

    Returns:
        list: ``(path, text)`` tuples.
    """
    results = []
    current_path = f"{path}/{element.tag}" if path else element.tag

    # Get direct text content
    if element.text and element.text.strip():
        results.append((current_path, element.text.strip()))

    # Process attributes if any
    for key, value in element.attrib.items():
        results.append((f"{current_path}/@{key}", value))

    # Process child elements
    for child in element:
        results.extend(extract_xml_structure(child, current_path))
        # Text after a child's closing tag belongs to this element
        # (e.g. " world" in "<p>Hello <b>bold</b> world</p>").
        tail = child.tail.strip() if child.tail else ""
        if tail:
            results.append((current_path, tail))

    return results


def _build_xml_chunk(items, chunk_index, max_size, overlap, language, root):
    """Assemble one chunk dict (text plus metadata) from ``(path, text)`` items."""
    return {
        'text': '\n'.join(f"{p}: {c}" for p, c in items),
        'metadata': {
            'paths': [p for p, _ in items],
            'chunk_method': 'xml',
            'chunk_index': chunk_index,
            'max_size': max_size,
            'overlap': overlap,
            'language': language,
            'root_tag': root.tag,
            'xml_attributes': dict(root.attrib)
        }
    }


def chunk_xml(xml_text: str, chunk_options: Optional[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """
    Chunk XML content while keeping each value's structural path.

    The document is flattened with ``extract_xml_structure`` and the
    resulting ``path: text`` lines are grouped so that each chunk holds at
    most ``max_size`` words (a single oversized item still forms its own
    chunk).

    Args:
        xml_text (str): The XML content as a string
        chunk_options (Dict[str, Any]): Configuration options (``None`` is
            treated as empty):
            - max_size (int): Maximum chunk size in words (default: 1000)
            - overlap (int): Number of trailing items repeated at the start
              of the next chunk (default: 0; negative values count as 0)
            - language (str): Content language (default: 'english')

    Returns:
        List[Dict[str, Any]]: List of chunks, each containing:
            - text: The chunk content
            - metadata: paths, chunk_method, chunk_index, total_chunks,
              max_size, overlap, language, root_tag, xml_attributes

    Raises:
        xml.etree.ElementTree.ParseError: If ``xml_text`` is not well-formed.
    """
    logging.debug("Starting XML chunking process...")

    try:
        # NOTE(review): ElementTree is not hardened against maliciously
        # constructed XML (e.g. entity-expansion attacks); confirm inputs
        # are trusted or switch to a hardened parser.
        root = ET.fromstring(xml_text)

        # Get chunking parameters with defaults
        options = chunk_options or {}
        max_size = int(options.get('max_size', 1000))
        overlap = max(0, int(options.get('overlap', 0)))
        language = options.get('language', 'english')

        logging.debug(f"Chunking parameters - max_size: {max_size}, overlap: {overlap}, language: {language}")

        # Extract full structure with hierarchy
        xml_content = extract_xml_structure(root)
        logging.debug(f"Extracted {len(xml_content)} XML elements")

        chunks = []
        current_chunk = []
        current_size = 0

        for path, content in xml_content:
            # Calculate content size (by words)
            content_size = len(content.split())

            # Close the current chunk if this item would push it past max_size.
            if current_size + content_size > max_size and current_chunk:
                chunks.append(
                    _build_xml_chunk(current_chunk, len(chunks) + 1, max_size, overlap, language, root)
                )

                if overlap > 0:
                    # Keep last few items for overlap
                    current_chunk = current_chunk[-overlap:]
                    current_size = sum(len(c.split()) for _, c in current_chunk)
                    logging.debug(f"Created overlap chunk with {len(current_chunk)} items")
                else:
                    current_chunk = []
                    current_size = 0

            # Add current content to chunk
            current_chunk.append((path, content))
            current_size += content_size

        # Process final chunk if content remains
        if current_chunk:
            chunks.append(
                _build_xml_chunk(current_chunk, len(chunks) + 1, max_size, overlap, language, root)
            )

        # The total is only known once every chunk exists.
        total_chunks = len(chunks)
        for chunk in chunks:
            chunk['metadata']['total_chunks'] = total_chunks

        logging.debug(f"XML chunking complete. Created {len(chunks)} chunks")
        return chunks

    except ET.ParseError as e:
        logging.error(f"XML parsing error: {str(e)}")
        raise
    except Exception as e:
        logging.error(f"Unexpected error during XML chunking: {str(e)}")
        raise

#
# End of XML Chunking
#######################################################################################################################
-# # -# Imports -import configparser -import sqlite3 -import json -import os -import sys -from typing import List, Dict, Optional, Tuple, Any, Union - -from App_Function_Libraries.Utils.Utils import get_database_dir, get_project_relative_path, get_database_path -import logging - -# -####################################################################################################################### -# -# - -def ensure_database_directory(): - os.makedirs(get_database_dir(), exist_ok=True) - -ensure_database_directory() - - -# Construct the path to the config file -config_path = get_project_relative_path('Config_Files/config.txt') - -# Read the config file -config = configparser.ConfigParser() -config.read(config_path) - -# Get the chat db path from the config, or use the default if not specified -chat_DB_PATH = config.get('Database', 'chatDB_path', fallback=get_database_path('chatDB.db')) -print(f"Chat Database path: {chat_DB_PATH}") - -######################################################################################################## -# -# Functions - -# FIXME - Setup properly and test/add documentation for its existence... 
-def initialize_database(): - """Initialize the SQLite database with required tables and FTS5 virtual tables.""" - conn = None - try: - conn = sqlite3.connect(chat_DB_PATH) - cursor = conn.cursor() - - # Enable foreign key constraints - cursor.execute("PRAGMA foreign_keys = ON;") - - # Create CharacterCards table with V2 fields - cursor.execute(""" - CREATE TABLE IF NOT EXISTS CharacterCards ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - name TEXT UNIQUE NOT NULL, - description TEXT, - personality TEXT, - scenario TEXT, - image BLOB, - post_history_instructions TEXT, - first_mes TEXT, - mes_example TEXT, - creator_notes TEXT, - system_prompt TEXT, - alternate_greetings TEXT, - tags TEXT, - creator TEXT, - character_version TEXT, - extensions TEXT, - created_at DATETIME DEFAULT CURRENT_TIMESTAMP - ); - """) - - # Create CharacterChats table - cursor.execute(""" - CREATE TABLE IF NOT EXISTS CharacterChats ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - character_id INTEGER NOT NULL, - conversation_name TEXT, - chat_history TEXT, - is_snapshot BOOLEAN DEFAULT FALSE, - created_at DATETIME DEFAULT CURRENT_TIMESTAMP, - FOREIGN KEY (character_id) REFERENCES CharacterCards(id) ON DELETE CASCADE - ); - """) - - # Create FTS5 virtual table for CharacterChats - cursor.execute(""" - CREATE VIRTUAL TABLE IF NOT EXISTS CharacterChats_fts USING fts5( - conversation_name, - chat_history, - content='CharacterChats', - content_rowid='id' - ); - """) - - # Create triggers to keep FTS5 table in sync with CharacterChats - cursor.executescript(""" - CREATE TRIGGER IF NOT EXISTS CharacterChats_ai AFTER INSERT ON CharacterChats BEGIN - INSERT INTO CharacterChats_fts(rowid, conversation_name, chat_history) - VALUES (new.id, new.conversation_name, new.chat_history); - END; - - CREATE TRIGGER IF NOT EXISTS CharacterChats_ad AFTER DELETE ON CharacterChats BEGIN - DELETE FROM CharacterChats_fts WHERE rowid = old.id; - END; - - CREATE TRIGGER IF NOT EXISTS CharacterChats_au AFTER UPDATE ON 
CharacterChats BEGIN - UPDATE CharacterChats_fts SET conversation_name = new.conversation_name, chat_history = new.chat_history - WHERE rowid = new.id; - END; - """) - - # Create ChatKeywords table - cursor.execute(""" - CREATE TABLE IF NOT EXISTS ChatKeywords ( - chat_id INTEGER NOT NULL, - keyword TEXT NOT NULL, - FOREIGN KEY (chat_id) REFERENCES CharacterChats(id) ON DELETE CASCADE - ); - """) - - # Create indexes for faster searches - cursor.execute(""" - CREATE INDEX IF NOT EXISTS idx_chatkeywords_keyword ON ChatKeywords(keyword); - """) - cursor.execute(""" - CREATE INDEX IF NOT EXISTS idx_chatkeywords_chat_id ON ChatKeywords(chat_id); - """) - - conn.commit() - logging.info("Database initialized successfully.") - except sqlite3.Error as e: - logging.error(f"SQLite error occurred during database initialization: {e}") - if conn: - conn.rollback() - raise - except Exception as e: - logging.error(f"Unexpected error occurred during database initialization: {e}") - if conn: - conn.rollback() - raise - finally: - if conn: - conn.close() - -# Call initialize_database() at the start of your application -def setup_chat_database(): - try: - initialize_database() - except Exception as e: - logging.critical(f"Failed to initialize database: {e}") - sys.exit(1) - -setup_chat_database() - -######################################################################################################## -# -# Character Card handling - -def parse_character_card(card_data: Dict[str, Any]) -> Dict[str, Any]: - """Parse and validate a character card according to V2 specification.""" - v2_data = { - 'name': card_data.get('name', ''), - 'description': card_data.get('description', ''), - 'personality': card_data.get('personality', ''), - 'scenario': card_data.get('scenario', ''), - 'first_mes': card_data.get('first_mes', ''), - 'mes_example': card_data.get('mes_example', ''), - 'creator_notes': card_data.get('creator_notes', ''), - 'system_prompt': card_data.get('system_prompt', ''), - 
'post_history_instructions': card_data.get('post_history_instructions', ''), - 'alternate_greetings': json.dumps(card_data.get('alternate_greetings', [])), - 'tags': json.dumps(card_data.get('tags', [])), - 'creator': card_data.get('creator', ''), - 'character_version': card_data.get('character_version', ''), - 'extensions': json.dumps(card_data.get('extensions', {})) - } - - # Handle 'image' separately as it might be binary data - if 'image' in card_data: - v2_data['image'] = card_data['image'] - - return v2_data - - -def add_character_card(card_data: Dict[str, Any]) -> Optional[int]: - """Add or update a character card in the database.""" - conn = sqlite3.connect(chat_DB_PATH) - cursor = conn.cursor() - try: - parsed_card = parse_character_card(card_data) - - # Check if character already exists - cursor.execute("SELECT id FROM CharacterCards WHERE name = ?", (parsed_card['name'],)) - row = cursor.fetchone() - - if row: - # Update existing character - character_id = row[0] - update_query = """ - UPDATE CharacterCards - SET description = ?, personality = ?, scenario = ?, image = ?, - post_history_instructions = ?, first_mes = ?, mes_example = ?, - creator_notes = ?, system_prompt = ?, alternate_greetings = ?, - tags = ?, creator = ?, character_version = ?, extensions = ? - WHERE id = ? 
- """ - cursor.execute(update_query, ( - parsed_card['description'], parsed_card['personality'], parsed_card['scenario'], - parsed_card['image'], parsed_card['post_history_instructions'], parsed_card['first_mes'], - parsed_card['mes_example'], parsed_card['creator_notes'], parsed_card['system_prompt'], - parsed_card['alternate_greetings'], parsed_card['tags'], parsed_card['creator'], - parsed_card['character_version'], parsed_card['extensions'], character_id - )) - else: - # Insert new character - insert_query = """ - INSERT INTO CharacterCards (name, description, personality, scenario, image, - post_history_instructions, first_mes, mes_example, creator_notes, system_prompt, - alternate_greetings, tags, creator, character_version, extensions) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) - """ - cursor.execute(insert_query, ( - parsed_card['name'], parsed_card['description'], parsed_card['personality'], - parsed_card['scenario'], parsed_card['image'], parsed_card['post_history_instructions'], - parsed_card['first_mes'], parsed_card['mes_example'], parsed_card['creator_notes'], - parsed_card['system_prompt'], parsed_card['alternate_greetings'], parsed_card['tags'], - parsed_card['creator'], parsed_card['character_version'], parsed_card['extensions'] - )) - character_id = cursor.lastrowid - - conn.commit() - return character_id - except sqlite3.IntegrityError as e: - logging.error(f"Error adding character card: {e}") - return None - except Exception as e: - logging.error(f"Unexpected error adding character card: {e}") - return None - finally: - conn.close() - -# def add_character_card(card_data: Dict) -> Optional[int]: -# """Add or update a character card in the database. -# -# Returns the ID of the inserted character or None if failed. 
-# """ -# conn = sqlite3.connect(chat_DB_PATH) -# cursor = conn.cursor() -# try: -# # Ensure all required fields are present -# required_fields = ['name', 'description', 'personality', 'scenario', 'image', 'post_history_instructions', 'first_message'] -# for field in required_fields: -# if field not in card_data: -# card_data[field] = '' # Assign empty string if field is missing -# -# # Check if character already exists -# cursor.execute("SELECT id FROM CharacterCards WHERE name = ?", (card_data['name'],)) -# row = cursor.fetchone() -# -# if row: -# # Update existing character -# character_id = row[0] -# cursor.execute(""" -# UPDATE CharacterCards -# SET description = ?, personality = ?, scenario = ?, image = ?, post_history_instructions = ?, first_message = ? -# WHERE id = ? -# """, ( -# card_data['description'], -# card_data['personality'], -# card_data['scenario'], -# card_data['image'], -# card_data['post_history_instructions'], -# card_data['first_message'], -# character_id -# )) -# else: -# # Insert new character -# cursor.execute(""" -# INSERT INTO CharacterCards (name, description, personality, scenario, image, post_history_instructions, first_message) -# VALUES (?, ?, ?, ?, ?, ?, ?) 
-# """, ( -# card_data['name'], -# card_data['description'], -# card_data['personality'], -# card_data['scenario'], -# card_data['image'], -# card_data['post_history_instructions'], -# card_data['first_message'] -# )) -# character_id = cursor.lastrowid -# -# conn.commit() -# return cursor.lastrowid -# except sqlite3.IntegrityError as e: -# logging.error(f"Error adding character card: {e}") -# return None -# except Exception as e: -# logging.error(f"Unexpected error adding character card: {e}") -# return None -# finally: -# conn.close() - - -def get_character_cards() -> List[Dict]: - """Retrieve all character cards from the database.""" - logging.debug(f"Fetching characters from DB: {chat_DB_PATH}") - conn = sqlite3.connect(chat_DB_PATH) - cursor = conn.cursor() - cursor.execute("SELECT * FROM CharacterCards") - rows = cursor.fetchall() - columns = [description[0] for description in cursor.description] - conn.close() - characters = [dict(zip(columns, row)) for row in rows] - #logging.debug(f"Characters fetched from DB: {characters}") - return characters - - -def get_character_card_by_id(character_id: Union[int, Dict[str, Any]]) -> Optional[Dict[str, Any]]: - """ - Retrieve a single character card by its ID. - - Args: - character_id: Can be either an integer ID or a dictionary containing character data. - - Returns: - A dictionary containing the character card data, or None if not found. 
- """ - conn = sqlite3.connect(chat_DB_PATH) - cursor = conn.cursor() - try: - if isinstance(character_id, dict): - # If a dictionary is passed, assume it's already a character card - return character_id - elif isinstance(character_id, int): - # If an integer is passed, fetch the character from the database - cursor.execute("SELECT * FROM CharacterCards WHERE id = ?", (character_id,)) - row = cursor.fetchone() - if row: - columns = [description[0] for description in cursor.description] - return dict(zip(columns, row)) - else: - logging.warning(f"Invalid type for character_id: {type(character_id)}") - return None - except Exception as e: - logging.error(f"Error in get_character_card_by_id: {e}") - return None - finally: - conn.close() - - -def update_character_card(character_id: int, card_data: Dict) -> bool: - """Update an existing character card.""" - conn = sqlite3.connect(chat_DB_PATH) - cursor = conn.cursor() - try: - cursor.execute(""" - UPDATE CharacterCards - SET name = ?, description = ?, personality = ?, scenario = ?, image = ?, post_history_instructions = ?, first_message = ? - WHERE id = ? - """, ( - card_data.get('name'), - card_data.get('description'), - card_data.get('personality'), - card_data.get('scenario'), - card_data.get('image'), - card_data.get('post_history_instructions', ''), - card_data.get('first_message', "Hello! 
I'm ready to chat."), - character_id - )) - conn.commit() - return cursor.rowcount > 0 - except sqlite3.IntegrityError as e: - logging.error(f"Error updating character card: {e}") - return False - finally: - conn.close() - - -def delete_character_card(character_id: int) -> bool: - """Delete a character card and its associated chats.""" - conn = sqlite3.connect(chat_DB_PATH) - cursor = conn.cursor() - try: - # Delete associated chats first due to foreign key constraint - cursor.execute("DELETE FROM CharacterChats WHERE character_id = ?", (character_id,)) - cursor.execute("DELETE FROM CharacterCards WHERE id = ?", (character_id,)) - conn.commit() - return cursor.rowcount > 0 - except sqlite3.Error as e: - logging.error(f"Error deleting character card: {e}") - return False - finally: - conn.close() - - -def add_character_chat(character_id: int, conversation_name: str, chat_history: List[Tuple[str, str]], keywords: Optional[List[str]] = None, is_snapshot: bool = False) -> Optional[int]: - """ - Add a new chat history for a character, optionally associating keywords. - - Args: - character_id (int): The ID of the character. - conversation_name (str): Name of the conversation. - chat_history (List[Tuple[str, str]]): List of (user, bot) message tuples. - keywords (Optional[List[str]]): List of keywords to associate with this chat. - is_snapshot (bool, optional): Whether this chat is a snapshot. - - Returns: - Optional[int]: The ID of the inserted chat or None if failed. - """ - conn = sqlite3.connect(chat_DB_PATH) - cursor = conn.cursor() - try: - chat_history_json = json.dumps(chat_history) - cursor.execute(""" - INSERT INTO CharacterChats (character_id, conversation_name, chat_history, is_snapshot) - VALUES (?, ?, ?, ?) 
- """, ( - character_id, - conversation_name, - chat_history_json, - is_snapshot - )) - chat_id = cursor.lastrowid - - if keywords: - # Insert keywords into ChatKeywords table - keyword_records = [(chat_id, keyword.strip().lower()) for keyword in keywords] - cursor.executemany(""" - INSERT INTO ChatKeywords (chat_id, keyword) - VALUES (?, ?) - """, keyword_records) - - conn.commit() - return chat_id - except sqlite3.Error as e: - logging.error(f"Error adding character chat: {e}") - return None - finally: - conn.close() - - -def get_character_chats(character_id: Optional[int] = None) -> List[Dict]: - """Retrieve all chats, or chats for a specific character if character_id is provided.""" - conn = sqlite3.connect(chat_DB_PATH) - cursor = conn.cursor() - if character_id is not None: - cursor.execute("SELECT * FROM CharacterChats WHERE character_id = ?", (character_id,)) - else: - cursor.execute("SELECT * FROM CharacterChats") - rows = cursor.fetchall() - columns = [description[0] for description in cursor.description] - conn.close() - return [dict(zip(columns, row)) for row in rows] - - -def get_character_chat_by_id(chat_id: int) -> Optional[Dict]: - """Retrieve a single chat by its ID.""" - conn = sqlite3.connect(chat_DB_PATH) - cursor = conn.cursor() - cursor.execute("SELECT * FROM CharacterChats WHERE id = ?", (chat_id,)) - row = cursor.fetchone() - conn.close() - if row: - columns = [description[0] for description in cursor.description] - chat = dict(zip(columns, row)) - chat['chat_history'] = json.loads(chat['chat_history']) - return chat - return None - - -def search_character_chats(query: str, character_id: Optional[int] = None) -> Tuple[List[Dict], str]: - """ - Search for character chats using FTS5, optionally filtered by character_id. - - Args: - query (str): The search query. - character_id (Optional[int]): The ID of the character to filter chats by. - - Returns: - Tuple[List[Dict], str]: A list of matching chats and a status message. 
- """ - if not query.strip(): - return [], "Please enter a search query." - - conn = sqlite3.connect(chat_DB_PATH) - cursor = conn.cursor() - try: - if character_id is not None: - # Search with character_id filter - cursor.execute(""" - SELECT CharacterChats.id, CharacterChats.conversation_name, CharacterChats.chat_history - FROM CharacterChats_fts - JOIN CharacterChats ON CharacterChats_fts.rowid = CharacterChats.id - WHERE CharacterChats_fts MATCH ? AND CharacterChats.character_id = ? - ORDER BY rank - """, (query, character_id)) - else: - # Search without character_id filter - cursor.execute(""" - SELECT CharacterChats.id, CharacterChats.conversation_name, CharacterChats.chat_history - FROM CharacterChats_fts - JOIN CharacterChats ON CharacterChats_fts.rowid = CharacterChats.id - WHERE CharacterChats_fts MATCH ? - ORDER BY rank - """, (query,)) - - rows = cursor.fetchall() - columns = [description[0] for description in cursor.description] - results = [dict(zip(columns, row)) for row in rows] - - if character_id is not None: - status_message = f"Found {len(results)} chat(s) matching '{query}' for the selected character." - else: - status_message = f"Found {len(results)} chat(s) matching '{query}' across all characters." - - return results, status_message - except Exception as e: - logging.error(f"Error searching chats with FTS5: {e}") - return [], f"Error occurred during search: {e}" - finally: - conn.close() - -def update_character_chat(chat_id: int, chat_history: List[Tuple[str, str]]) -> bool: - """Update an existing chat history.""" - conn = sqlite3.connect(chat_DB_PATH) - cursor = conn.cursor() - try: - chat_history_json = json.dumps(chat_history) - cursor.execute(""" - UPDATE CharacterChats - SET chat_history = ? - WHERE id = ? 
- """, ( - chat_history_json, - chat_id - )) - conn.commit() - return cursor.rowcount > 0 - except sqlite3.Error as e: - logging.error(f"Error updating character chat: {e}") - return False - finally: - conn.close() - - -def delete_character_chat(chat_id: int) -> bool: - """Delete a specific chat.""" - conn = sqlite3.connect(chat_DB_PATH) - cursor = conn.cursor() - try: - cursor.execute("DELETE FROM CharacterChats WHERE id = ?", (chat_id,)) - conn.commit() - return cursor.rowcount > 0 - except sqlite3.Error as e: - logging.error(f"Error deleting character chat: {e}") - return False - finally: - conn.close() - -def fetch_keywords_for_chats(keywords: List[str]) -> List[int]: - """ - Fetch chat IDs associated with any of the specified keywords. - - Args: - keywords (List[str]): List of keywords to search for. - - Returns: - List[int]: List of chat IDs associated with the keywords. - """ - if not keywords: - return [] - - conn = sqlite3.connect(chat_DB_PATH) - cursor = conn.cursor() - try: - # Construct the WHERE clause to search for each keyword - keyword_clauses = " OR ".join(["keyword = ?"] * len(keywords)) - sql_query = f"SELECT DISTINCT chat_id FROM ChatKeywords WHERE {keyword_clauses}" - cursor.execute(sql_query, keywords) - rows = cursor.fetchall() - chat_ids = [row[0] for row in rows] - return chat_ids - except Exception as e: - logging.error(f"Error in fetch_keywords_for_chats: {e}") - return [] - finally: - conn.close() - -def save_chat_history_to_character_db(character_id: int, conversation_name: str, chat_history: List[Tuple[str, str]]) -> Optional[int]: - """Save chat history to the CharacterChats table. - - Returns the ID of the inserted chat or None if failed. 
- """ - return add_character_chat(character_id, conversation_name, chat_history) - -def migrate_chat_to_media_db(): - pass - - -def search_db(query: str, fields: List[str], where_clause: str = "", page: int = 1, results_per_page: int = 5) -> List[Dict[str, Any]]: - """ - Perform a full-text search on specified fields with optional filtering and pagination. - - Args: - query (str): The search query. - fields (List[str]): List of fields to search in. - where_clause (str, optional): Additional SQL WHERE clause to filter results. - page (int, optional): Page number for pagination. - results_per_page (int, optional): Number of results per page. - - Returns: - List[Dict[str, Any]]: List of matching chat records with content and metadata. - """ - if not query.strip(): - return [] - - conn = sqlite3.connect(chat_DB_PATH) - cursor = conn.cursor() - try: - # Construct the MATCH query for FTS5 - match_query = " AND ".join(fields) + f" MATCH ?" - # Adjust the query with the fields - fts_query = f""" - SELECT CharacterChats.id, CharacterChats.conversation_name, CharacterChats.chat_history - FROM CharacterChats_fts - JOIN CharacterChats ON CharacterChats_fts.rowid = CharacterChats.id - WHERE {match_query} - """ - if where_clause: - fts_query += f" AND ({where_clause})" - fts_query += " ORDER BY rank LIMIT ? OFFSET ?" - offset = (page - 1) * results_per_page - cursor.execute(fts_query, (query, results_per_page, offset)) - rows = cursor.fetchall() - columns = [description[0] for description in cursor.description] - results = [dict(zip(columns, row)) for row in rows] - return results - except Exception as e: - logging.error(f"Error in search_db: {e}") - return [] - finally: - conn.close() - - -def perform_full_text_search_chat(query: str, relevant_chat_ids: List[int], page: int = 1, results_per_page: int = 5) -> \ -List[Dict[str, Any]]: - """ - Perform a full-text search within the specified chat IDs using FTS5. - - Args: - query (str): The user's query. 
- relevant_chat_ids (List[int]): List of chat IDs to search within. - page (int): Pagination page number. - results_per_page (int): Number of results per page. - - Returns: - List[Dict[str, Any]]: List of search results with content and metadata. - """ - try: - # Construct a WHERE clause to limit the search to relevant chat IDs - where_clause = " OR ".join([f"media_id = {chat_id}" for chat_id in relevant_chat_ids]) - if not where_clause: - where_clause = "1" # No restriction if no chat IDs - - # Perform full-text search using FTS5 - fts_results = search_db(query, ["content"], where_clause, page=page, results_per_page=results_per_page) - - filtered_fts_results = [ - { - "content": result['content'], - "metadata": {"media_id": result['id']} - } - for result in fts_results - if result['id'] in relevant_chat_ids - ] - return filtered_fts_results - except Exception as e: - logging.error(f"Error in perform_full_text_search_chat: {str(e)}") - return [] - - -def fetch_all_chats() -> List[Dict[str, Any]]: - """ - Fetch all chat messages from the database. - - Returns: - List[Dict[str, Any]]: List of chat messages with relevant metadata. - """ - try: - chats = get_character_chats() # Modify this function to retrieve all chats - return chats - except Exception as e: - logging.error(f"Error fetching all chats: {str(e)}") - return [] - -# -# End of Character_Chat_DB.py -####################################################################################################################### +# character_chat_db.py +# Database functions for managing character cards and chat histories. 
+# # +# Imports +import configparser +import sqlite3 +import json +import os +import sys +from typing import List, Dict, Optional, Tuple, Any, Union + +from App_Function_Libraries.Utils.Utils import get_database_dir, get_project_relative_path, get_database_path +from Tests.Chat_APIs.Chat_APIs_Integration_test import logging + +# +####################################################################################################################### +# +# + +def ensure_database_directory(): + os.makedirs(get_database_dir(), exist_ok=True) + +ensure_database_directory() + + +# Construct the path to the config file +config_path = get_project_relative_path('Config_Files/config.txt') + +# Read the config file +config = configparser.ConfigParser() +config.read(config_path) + +# Get the chat db path from the config, or use the default if not specified +chat_DB_PATH = config.get('Database', 'chatDB_path', fallback=get_database_path('chatDB.db')) +print(f"Chat Database path: {chat_DB_PATH}") + +######################################################################################################## +# +# Functions + +# FIXME - Setup properly and test/add documentation for its existence... 
+def initialize_database(): + """Initialize the SQLite database with required tables and FTS5 virtual tables.""" + conn = None + try: + conn = sqlite3.connect(chat_DB_PATH) + cursor = conn.cursor() + + # Enable foreign key constraints + cursor.execute("PRAGMA foreign_keys = ON;") + + # Create CharacterCards table with V2 fields + cursor.execute(""" + CREATE TABLE IF NOT EXISTS CharacterCards ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + name TEXT UNIQUE NOT NULL, + description TEXT, + personality TEXT, + scenario TEXT, + image BLOB, + post_history_instructions TEXT, + first_mes TEXT, + mes_example TEXT, + creator_notes TEXT, + system_prompt TEXT, + alternate_greetings TEXT, + tags TEXT, + creator TEXT, + character_version TEXT, + extensions TEXT, + created_at DATETIME DEFAULT CURRENT_TIMESTAMP + ); + """) + + # Create FTS5 virtual table for CharacterCards + cursor.execute(""" + CREATE VIRTUAL TABLE IF NOT EXISTS CharacterCards_fts USING fts5( + name, + description, + personality, + scenario, + system_prompt, + content='CharacterCards', + content_rowid='id' + ); + """) + + # Create triggers to keep FTS5 table in sync with CharacterCards + cursor.executescript(""" + CREATE TRIGGER IF NOT EXISTS CharacterCards_ai AFTER INSERT ON CharacterCards BEGIN + INSERT INTO CharacterCards_fts( + rowid, + name, + description, + personality, + scenario, + system_prompt + ) VALUES ( + new.id, + new.name, + new.description, + new.personality, + new.scenario, + new.system_prompt + ); + END; + + CREATE TRIGGER IF NOT EXISTS CharacterCards_ad AFTER DELETE ON CharacterCards BEGIN + DELETE FROM CharacterCards_fts WHERE rowid = old.id; + END; + + CREATE TRIGGER IF NOT EXISTS CharacterCards_au AFTER UPDATE ON CharacterCards BEGIN + UPDATE CharacterCards_fts SET + name = new.name, + description = new.description, + personality = new.personality, + scenario = new.scenario, + system_prompt = new.system_prompt + WHERE rowid = new.id; + END; + """) + + # Create CharacterChats table + 
cursor.execute(""" + CREATE TABLE IF NOT EXISTS CharacterChats ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + character_id INTEGER NOT NULL, + conversation_name TEXT, + chat_history TEXT, + is_snapshot BOOLEAN DEFAULT FALSE, + created_at DATETIME DEFAULT CURRENT_TIMESTAMP, + FOREIGN KEY (character_id) REFERENCES CharacterCards(id) ON DELETE CASCADE + ); + """) + + # Create FTS5 virtual table for CharacterChats + cursor.execute(""" + CREATE VIRTUAL TABLE IF NOT EXISTS CharacterChats_fts USING fts5( + conversation_name, + chat_history, + content='CharacterChats', + content_rowid='id' + ); + """) + + # Create triggers to keep FTS5 table in sync with CharacterChats + cursor.executescript(""" + CREATE TRIGGER IF NOT EXISTS CharacterChats_ai AFTER INSERT ON CharacterChats BEGIN + INSERT INTO CharacterChats_fts(rowid, conversation_name, chat_history) + VALUES (new.id, new.conversation_name, new.chat_history); + END; + + CREATE TRIGGER IF NOT EXISTS CharacterChats_ad AFTER DELETE ON CharacterChats BEGIN + DELETE FROM CharacterChats_fts WHERE rowid = old.id; + END; + + CREATE TRIGGER IF NOT EXISTS CharacterChats_au AFTER UPDATE ON CharacterChats BEGIN + UPDATE CharacterChats_fts SET conversation_name = new.conversation_name, chat_history = new.chat_history + WHERE rowid = new.id; + END; + """) + + # Create ChatKeywords table + cursor.execute(""" + CREATE TABLE IF NOT EXISTS ChatKeywords ( + chat_id INTEGER NOT NULL, + keyword TEXT NOT NULL, + FOREIGN KEY (chat_id) REFERENCES CharacterChats(id) ON DELETE CASCADE + ); + """) + + # Create indexes for faster searches + cursor.execute(""" + CREATE INDEX IF NOT EXISTS idx_chatkeywords_keyword ON ChatKeywords(keyword); + """) + cursor.execute(""" + CREATE INDEX IF NOT EXISTS idx_chatkeywords_chat_id ON ChatKeywords(chat_id); + """) + + conn.commit() + logging.info("Database initialized successfully.") + except sqlite3.Error as e: + logging.error(f"SQLite error occurred during database initialization: {e}") + if conn: + 
conn.rollback() + raise + except Exception as e: + logging.error(f"Unexpected error occurred during database initialization: {e}") + if conn: + conn.rollback() + raise + finally: + if conn: + conn.close() + +# Call initialize_database() at the start of your application +def setup_chat_database(): + try: + initialize_database() + except Exception as e: + logging.critical(f"Failed to initialize database: {e}") + sys.exit(1) + +setup_chat_database() + + +######################################################################################################## +# +# Character Card handling + +def parse_character_card(card_data: Dict[str, Any]) -> Dict[str, Any]: + """Parse and validate a character card according to V2 specification.""" + v2_data = { + 'name': card_data.get('name', ''), + 'description': card_data.get('description', ''), + 'personality': card_data.get('personality', ''), + 'scenario': card_data.get('scenario', ''), + 'first_mes': card_data.get('first_mes', ''), + 'mes_example': card_data.get('mes_example', ''), + 'creator_notes': card_data.get('creator_notes', ''), + 'system_prompt': card_data.get('system_prompt', ''), + 'post_history_instructions': card_data.get('post_history_instructions', ''), + 'alternate_greetings': json.dumps(card_data.get('alternate_greetings', [])), + 'tags': json.dumps(card_data.get('tags', [])), + 'creator': card_data.get('creator', ''), + 'character_version': card_data.get('character_version', ''), + 'extensions': json.dumps(card_data.get('extensions', {})) + } + + # Handle 'image' separately as it might be binary data + if 'image' in card_data: + v2_data['image'] = card_data['image'] + + return v2_data + + +def add_character_card(card_data: Dict[str, Any]) -> Optional[int]: + """Add or update a character card in the database.""" + conn = sqlite3.connect(chat_DB_PATH) + cursor = conn.cursor() + try: + parsed_card = parse_character_card(card_data) + + # Check if character already exists + cursor.execute("SELECT id FROM 
CharacterCards WHERE name = ?", (parsed_card['name'],)) + row = cursor.fetchone() + + if row: + # Update existing character + character_id = row[0] + update_query = """ + UPDATE CharacterCards + SET description = ?, personality = ?, scenario = ?, image = ?, + post_history_instructions = ?, first_mes = ?, mes_example = ?, + creator_notes = ?, system_prompt = ?, alternate_greetings = ?, + tags = ?, creator = ?, character_version = ?, extensions = ? + WHERE id = ? + """ + cursor.execute(update_query, ( + parsed_card['description'], parsed_card['personality'], parsed_card['scenario'], + parsed_card['image'], parsed_card['post_history_instructions'], parsed_card['first_mes'], + parsed_card['mes_example'], parsed_card['creator_notes'], parsed_card['system_prompt'], + parsed_card['alternate_greetings'], parsed_card['tags'], parsed_card['creator'], + parsed_card['character_version'], parsed_card['extensions'], character_id + )) + else: + # Insert new character + insert_query = """ + INSERT INTO CharacterCards (name, description, personality, scenario, image, + post_history_instructions, first_mes, mes_example, creator_notes, system_prompt, + alternate_greetings, tags, creator, character_version, extensions) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) 
+ """ + cursor.execute(insert_query, ( + parsed_card['name'], parsed_card['description'], parsed_card['personality'], + parsed_card['scenario'], parsed_card['image'], parsed_card['post_history_instructions'], + parsed_card['first_mes'], parsed_card['mes_example'], parsed_card['creator_notes'], + parsed_card['system_prompt'], parsed_card['alternate_greetings'], parsed_card['tags'], + parsed_card['creator'], parsed_card['character_version'], parsed_card['extensions'] + )) + character_id = cursor.lastrowid + + conn.commit() + return character_id + except sqlite3.IntegrityError as e: + logging.error(f"Error adding character card: {e}") + return None + except Exception as e: + logging.error(f"Unexpected error adding character card: {e}") + return None + finally: + conn.close() + +# def add_character_card(card_data: Dict) -> Optional[int]: +# """Add or update a character card in the database. +# +# Returns the ID of the inserted character or None if failed. +# """ +# conn = sqlite3.connect(chat_DB_PATH) +# cursor = conn.cursor() +# try: +# # Ensure all required fields are present +# required_fields = ['name', 'description', 'personality', 'scenario', 'image', 'post_history_instructions', 'first_message'] +# for field in required_fields: +# if field not in card_data: +# card_data[field] = '' # Assign empty string if field is missing +# +# # Check if character already exists +# cursor.execute("SELECT id FROM CharacterCards WHERE name = ?", (card_data['name'],)) +# row = cursor.fetchone() +# +# if row: +# # Update existing character +# character_id = row[0] +# cursor.execute(""" +# UPDATE CharacterCards +# SET description = ?, personality = ?, scenario = ?, image = ?, post_history_instructions = ?, first_message = ? +# WHERE id = ? 
+# """, ( +# card_data['description'], +# card_data['personality'], +# card_data['scenario'], +# card_data['image'], +# card_data['post_history_instructions'], +# card_data['first_message'], +# character_id +# )) +# else: +# # Insert new character +# cursor.execute(""" +# INSERT INTO CharacterCards (name, description, personality, scenario, image, post_history_instructions, first_message) +# VALUES (?, ?, ?, ?, ?, ?, ?) +# """, ( +# card_data['name'], +# card_data['description'], +# card_data['personality'], +# card_data['scenario'], +# card_data['image'], +# card_data['post_history_instructions'], +# card_data['first_message'] +# )) +# character_id = cursor.lastrowid +# +# conn.commit() +# return cursor.lastrowid +# except sqlite3.IntegrityError as e: +# logging.error(f"Error adding character card: {e}") +# return None +# except Exception as e: +# logging.error(f"Unexpected error adding character card: {e}") +# return None +# finally: +# conn.close() + + +def get_character_cards() -> List[Dict]: + """Retrieve all character cards from the database.""" + logging.debug(f"Fetching characters from DB: {chat_DB_PATH}") + conn = sqlite3.connect(chat_DB_PATH) + cursor = conn.cursor() + cursor.execute("SELECT * FROM CharacterCards") + rows = cursor.fetchall() + columns = [description[0] for description in cursor.description] + conn.close() + characters = [dict(zip(columns, row)) for row in rows] + #logging.debug(f"Characters fetched from DB: {characters}") + return characters + + +def get_character_card_by_id(character_id: Union[int, Dict[str, Any]]) -> Optional[Dict[str, Any]]: + """ + Retrieve a single character card by its ID. + + Args: + character_id: Can be either an integer ID or a dictionary containing character data. + + Returns: + A dictionary containing the character card data, or None if not found. 
+ """ + conn = sqlite3.connect(chat_DB_PATH) + cursor = conn.cursor() + try: + if isinstance(character_id, dict): + # If a dictionary is passed, assume it's already a character card + return character_id + elif isinstance(character_id, int): + # If an integer is passed, fetch the character from the database + cursor.execute("SELECT * FROM CharacterCards WHERE id = ?", (character_id,)) + row = cursor.fetchone() + if row: + columns = [description[0] for description in cursor.description] + return dict(zip(columns, row)) + else: + logging.warning(f"Invalid type for character_id: {type(character_id)}") + return None + except Exception as e: + logging.error(f"Error in get_character_card_by_id: {e}") + return None + finally: + conn.close() + + +def update_character_card(character_id: int, card_data: Dict) -> bool: + """Update an existing character card.""" + conn = sqlite3.connect(chat_DB_PATH) + cursor = conn.cursor() + try: + cursor.execute(""" + UPDATE CharacterCards + SET name = ?, description = ?, personality = ?, scenario = ?, image = ?, post_history_instructions = ?, first_message = ? + WHERE id = ? + """, ( + card_data.get('name'), + card_data.get('description'), + card_data.get('personality'), + card_data.get('scenario'), + card_data.get('image'), + card_data.get('post_history_instructions', ''), + card_data.get('first_message', "Hello! 
I'm ready to chat."), + character_id + )) + conn.commit() + return cursor.rowcount > 0 + except sqlite3.IntegrityError as e: + logging.error(f"Error updating character card: {e}") + return False + finally: + conn.close() + + +def delete_character_card(character_id: int) -> bool: + """Delete a character card and its associated chats.""" + conn = sqlite3.connect(chat_DB_PATH) + cursor = conn.cursor() + try: + # Delete associated chats first due to foreign key constraint + cursor.execute("DELETE FROM CharacterChats WHERE character_id = ?", (character_id,)) + cursor.execute("DELETE FROM CharacterCards WHERE id = ?", (character_id,)) + conn.commit() + return cursor.rowcount > 0 + except sqlite3.Error as e: + logging.error(f"Error deleting character card: {e}") + return False + finally: + conn.close() + + +def add_character_chat(character_id: int, conversation_name: str, chat_history: List[Tuple[str, str]], keywords: Optional[List[str]] = None, is_snapshot: bool = False) -> Optional[int]: + """ + Add a new chat history for a character, optionally associating keywords. + + Args: + character_id (int): The ID of the character. + conversation_name (str): Name of the conversation. + chat_history (List[Tuple[str, str]]): List of (user, bot) message tuples. + keywords (Optional[List[str]]): List of keywords to associate with this chat. + is_snapshot (bool, optional): Whether this chat is a snapshot. + + Returns: + Optional[int]: The ID of the inserted chat or None if failed. + """ + conn = sqlite3.connect(chat_DB_PATH) + cursor = conn.cursor() + try: + chat_history_json = json.dumps(chat_history) + cursor.execute(""" + INSERT INTO CharacterChats (character_id, conversation_name, chat_history, is_snapshot) + VALUES (?, ?, ?, ?) 
+ """, ( + character_id, + conversation_name, + chat_history_json, + is_snapshot + )) + chat_id = cursor.lastrowid + + if keywords: + # Insert keywords into ChatKeywords table + keyword_records = [(chat_id, keyword.strip().lower()) for keyword in keywords] + cursor.executemany(""" + INSERT INTO ChatKeywords (chat_id, keyword) + VALUES (?, ?) + """, keyword_records) + + conn.commit() + return chat_id + except sqlite3.Error as e: + logging.error(f"Error adding character chat: {e}") + return None + finally: + conn.close() + + +def get_character_chats(character_id: Optional[int] = None) -> List[Dict]: + """Retrieve all chats, or chats for a specific character if character_id is provided.""" + conn = sqlite3.connect(chat_DB_PATH) + cursor = conn.cursor() + if character_id is not None: + cursor.execute("SELECT * FROM CharacterChats WHERE character_id = ?", (character_id,)) + else: + cursor.execute("SELECT * FROM CharacterChats") + rows = cursor.fetchall() + columns = [description[0] for description in cursor.description] + conn.close() + return [dict(zip(columns, row)) for row in rows] + + +def get_character_chat_by_id(chat_id: int) -> Optional[Dict]: + """Retrieve a single chat by its ID.""" + conn = sqlite3.connect(chat_DB_PATH) + cursor = conn.cursor() + cursor.execute("SELECT * FROM CharacterChats WHERE id = ?", (chat_id,)) + row = cursor.fetchone() + conn.close() + if row: + columns = [description[0] for description in cursor.description] + chat = dict(zip(columns, row)) + chat['chat_history'] = json.loads(chat['chat_history']) + return chat + return None + + +def search_character_chats(query: str, character_id: Optional[int] = None) -> Tuple[List[Dict], str]: + """ + Search for character chats using FTS5, optionally filtered by character_id. + + Args: + query (str): The search query. + character_id (Optional[int]): The ID of the character to filter chats by. + + Returns: + Tuple[List[Dict], str]: A list of matching chats and a status message. 
+ """ + if not query.strip(): + return [], "Please enter a search query." + + conn = sqlite3.connect(chat_DB_PATH) + cursor = conn.cursor() + try: + if character_id is not None: + # Search with character_id filter + cursor.execute(""" + SELECT CharacterChats.id, CharacterChats.conversation_name, CharacterChats.chat_history + FROM CharacterChats_fts + JOIN CharacterChats ON CharacterChats_fts.rowid = CharacterChats.id + WHERE CharacterChats_fts MATCH ? AND CharacterChats.character_id = ? + ORDER BY rank + """, (query, character_id)) + else: + # Search without character_id filter + cursor.execute(""" + SELECT CharacterChats.id, CharacterChats.conversation_name, CharacterChats.chat_history + FROM CharacterChats_fts + JOIN CharacterChats ON CharacterChats_fts.rowid = CharacterChats.id + WHERE CharacterChats_fts MATCH ? + ORDER BY rank + """, (query,)) + + rows = cursor.fetchall() + columns = [description[0] for description in cursor.description] + results = [dict(zip(columns, row)) for row in rows] + + if character_id is not None: + status_message = f"Found {len(results)} chat(s) matching '{query}' for the selected character." + else: + status_message = f"Found {len(results)} chat(s) matching '{query}' across all characters." + + return results, status_message + except Exception as e: + logging.error(f"Error searching chats with FTS5: {e}") + return [], f"Error occurred during search: {e}" + finally: + conn.close() + +def update_character_chat(chat_id: int, chat_history: List[Tuple[str, str]]) -> bool: + """Update an existing chat history.""" + conn = sqlite3.connect(chat_DB_PATH) + cursor = conn.cursor() + try: + chat_history_json = json.dumps(chat_history) + cursor.execute(""" + UPDATE CharacterChats + SET chat_history = ? + WHERE id = ? 
+ """, ( + chat_history_json, + chat_id + )) + conn.commit() + return cursor.rowcount > 0 + except sqlite3.Error as e: + logging.error(f"Error updating character chat: {e}") + return False + finally: + conn.close() + + +def delete_character_chat(chat_id: int) -> bool: + """Delete a specific chat.""" + conn = sqlite3.connect(chat_DB_PATH) + cursor = conn.cursor() + try: + cursor.execute("DELETE FROM CharacterChats WHERE id = ?", (chat_id,)) + conn.commit() + return cursor.rowcount > 0 + except sqlite3.Error as e: + logging.error(f"Error deleting character chat: {e}") + return False + finally: + conn.close() + + +def fetch_keywords_for_chats(keywords: List[str]) -> List[int]: + """ + Fetch chat IDs associated with any of the specified keywords. + + Args: + keywords (List[str]): List of keywords to search for. + + Returns: + List[int]: List of chat IDs associated with the keywords. + """ + if not keywords: + return [] + + conn = sqlite3.connect(chat_DB_PATH) + cursor = conn.cursor() + try: + # Construct the WHERE clause to search for each keyword + keyword_clauses = " OR ".join(["keyword = ?"] * len(keywords)) + sql_query = f"SELECT DISTINCT chat_id FROM ChatKeywords WHERE {keyword_clauses}" + cursor.execute(sql_query, keywords) + rows = cursor.fetchall() + chat_ids = [row[0] for row in rows] + return chat_ids + except Exception as e: + logging.error(f"Error in fetch_keywords_for_chats: {e}") + return [] + finally: + conn.close() + + +def save_chat_history_to_character_db(character_id: int, conversation_name: str, chat_history: List[Tuple[str, str]]) -> Optional[int]: + """Save chat history to the CharacterChats table. + + Returns the ID of the inserted chat or None if failed. 
+ """ + return add_character_chat(character_id, conversation_name, chat_history) + + +def search_db(query: str, fields: List[str], where_clause: str = "", page: int = 1, results_per_page: int = 5) -> List[Dict[str, Any]]: + """ + Perform a full-text search on specified fields with optional filtering and pagination. + + Args: + query (str): The search query. + fields (List[str]): List of fields to search in. + where_clause (str, optional): Additional SQL WHERE clause to filter results. + page (int, optional): Page number for pagination. + results_per_page (int, optional): Number of results per page. + + Returns: + List[Dict[str, Any]]: List of matching chat records with content and metadata. + """ + if not query.strip(): + return [] + + conn = sqlite3.connect(chat_DB_PATH) + cursor = conn.cursor() + try: + # Construct the MATCH query for FTS5 + match_query = " AND ".join(fields) + f" MATCH ?" + # Adjust the query with the fields + fts_query = f""" + SELECT CharacterChats.id, CharacterChats.conversation_name, CharacterChats.chat_history + FROM CharacterChats_fts + JOIN CharacterChats ON CharacterChats_fts.rowid = CharacterChats.id + WHERE {match_query} + """ + if where_clause: + fts_query += f" AND ({where_clause})" + fts_query += " ORDER BY rank LIMIT ? OFFSET ?" + offset = (page - 1) * results_per_page + cursor.execute(fts_query, (query, results_per_page, offset)) + rows = cursor.fetchall() + columns = [description[0] for description in cursor.description] + results = [dict(zip(columns, row)) for row in rows] + return results + except Exception as e: + logging.error(f"Error in search_db: {e}") + return [] + finally: + conn.close() + + +def perform_full_text_search_chat(query: str, relevant_chat_ids: List[int], page: int = 1, results_per_page: int = 5) -> \ +List[Dict[str, Any]]: + """ + Perform a full-text search within the specified chat IDs using FTS5. + + Args: + query (str): The user's query. + relevant_chat_ids (List[int]): List of chat IDs to search within. 
+ page (int): Pagination page number. + results_per_page (int): Number of results per page. + + Returns: + List[Dict[str, Any]]: List of search results with content and metadata. + """ + try: + # Construct a WHERE clause to limit the search to relevant chat IDs + where_clause = " OR ".join([f"media_id = {chat_id}" for chat_id in relevant_chat_ids]) + if not where_clause: + where_clause = "1" # No restriction if no chat IDs + + # Perform full-text search using FTS5 + fts_results = search_db(query, ["content"], where_clause, page=page, results_per_page=results_per_page) + + filtered_fts_results = [ + { + "content": result['content'], + "metadata": {"media_id": result['id']} + } + for result in fts_results + if result['id'] in relevant_chat_ids + ] + return filtered_fts_results + except Exception as e: + logging.error(f"Error in perform_full_text_search_chat: {str(e)}") + return [] + + +def fetch_all_chats() -> List[Dict[str, Any]]: + """ + Fetch all chat messages from the database. + + Returns: + List[Dict[str, Any]]: List of chat messages with relevant metadata. + """ + try: + chats = get_character_chats() # Modify this function to retrieve all chats + return chats + except Exception as e: + logging.error(f"Error fetching all chats: {str(e)}") + return [] + + +def search_character_chat(query: str, fts_top_k: int = 10, relevant_media_ids: List[str] = None) -> List[Dict[str, Any]]: + """ + Perform a full-text search on the Character Chat database. + + Args: + query: Search query string. + fts_top_k: Maximum number of results to return. + relevant_media_ids: Optional list of character IDs to filter results. + + Returns: + List of search results with content and metadata. 
+ """ + if not query.strip(): + return [] + + try: + # Construct a WHERE clause to limit the search to relevant character IDs + where_clause = "" + if relevant_media_ids: + placeholders = ','.join(['?'] * len(relevant_media_ids)) + where_clause = f"CharacterChats.character_id IN ({placeholders})" + + # Perform full-text search using existing search_db function + results = search_db(query, ["conversation_name", "chat_history"], where_clause, results_per_page=fts_top_k) + + # Format results + formatted_results = [] + for r in results: + formatted_results.append({ + "content": r['chat_history'], + "metadata": { + "chat_id": r['id'], + "conversation_name": r['conversation_name'], + "character_id": r['character_id'] + } + }) + + return formatted_results + + except Exception as e: + logging.error(f"Error in search_character_chat: {e}") + return [] + + +def search_character_cards(query: str, fts_top_k: int = 10, relevant_media_ids: List[str] = None) -> List[Dict[str, Any]]: + """ + Perform a full-text search on the Character Cards database. + + Args: + query: Search query string. + fts_top_k: Maximum number of results to return. + relevant_media_ids: Optional list of character IDs to filter results. + + Returns: + List of search results with content and metadata. + """ + if not query.strip(): + return [] + + try: + conn = sqlite3.connect(chat_DB_PATH) + cursor = conn.cursor() + + # Construct the query + sql_query = """ + SELECT CharacterCards.id, CharacterCards.name, CharacterCards.description, CharacterCards.personality, CharacterCards.scenario + FROM CharacterCards_fts + JOIN CharacterCards ON CharacterCards_fts.rowid = CharacterCards.id + WHERE CharacterCards_fts MATCH ? + """ + + params = [query] + + # Add filtering by character IDs if provided + if relevant_media_ids: + placeholders = ','.join(['?'] * len(relevant_media_ids)) + sql_query += f" AND CharacterCards.id IN ({placeholders})" + params.extend(relevant_media_ids) + + sql_query += " LIMIT ?" 
+ params.append(fts_top_k) + + cursor.execute(sql_query, params) + rows = cursor.fetchall() + columns = [description[0] for description in cursor.description] + + results = [dict(zip(columns, row)) for row in rows] + + # Format results + formatted_results = [] + for r in results: + content = f"Name: {r['name']}\nDescription: {r['description']}\nPersonality: {r['personality']}\nScenario: {r['scenario']}" + formatted_results.append({ + "content": content, + "metadata": { + "character_id": r['id'], + "name": r['name'] + } + }) + + return formatted_results + + except Exception as e: + logging.error(f"Error in search_character_cards: {e}") + return [] + finally: + conn.close() + + +def fetch_character_ids_by_keywords(keywords: List[str]) -> List[int]: + """ + Fetch character IDs associated with any of the specified keywords. + + Args: + keywords (List[str]): List of keywords to search for. + + Returns: + List[int]: List of character IDs associated with the keywords. + """ + if not keywords: + return [] + + conn = sqlite3.connect(chat_DB_PATH) + cursor = conn.cursor() + try: + # Assuming 'tags' column in CharacterCards table stores tags as JSON array + placeholders = ','.join(['?'] * len(keywords)) + sql_query = f""" + SELECT DISTINCT id FROM CharacterCards + WHERE EXISTS ( + SELECT 1 FROM json_each(tags) + WHERE json_each.value IN ({placeholders}) + ) + """ + cursor.execute(sql_query, keywords) + rows = cursor.fetchall() + character_ids = [row[0] for row in rows] + return character_ids + except Exception as e: + logging.error(f"Error in fetch_character_ids_by_keywords: {e}") + return [] + finally: + conn.close() + + +################################################################### +# +# Character Keywords + +def view_char_keywords(): + try: + with sqlite3.connect(chat_DB_PATH) as conn: + cursor = conn.cursor() + cursor.execute(""" + SELECT DISTINCT keyword + FROM CharacterCards + CROSS JOIN json_each(tags) + WHERE json_valid(tags) + ORDER BY keyword + """) + keywords 
= cursor.fetchall() + if keywords: + keyword_list = [k[0] for k in keywords] + return "### Current Character Keywords:\n" + "\n".join( + [f"- {k}" for k in keyword_list]) + return "No keywords found." + except Exception as e: + return f"Error retrieving keywords: {str(e)}" + + +def add_char_keywords(name: str, keywords: str): + try: + keywords_list = [k.strip() for k in keywords.split(",") if k.strip()] + with sqlite3.connect('character_chat.db') as conn: + cursor = conn.cursor() + cursor.execute( + "SELECT tags FROM CharacterCards WHERE name = ?", + (name,) + ) + result = cursor.fetchone() + if not result: + return "Character not found." + + current_tags = result[0] if result[0] else "[]" + current_keywords = set(current_tags[1:-1].split(',')) if current_tags != "[]" else set() + updated_keywords = current_keywords.union(set(keywords_list)) + + cursor.execute( + "UPDATE CharacterCards SET tags = ? WHERE name = ?", + (str(list(updated_keywords)), name) + ) + conn.commit() + return f"Successfully added keywords to character {name}" + except Exception as e: + return f"Error adding keywords: {str(e)}" + + +def delete_char_keyword(char_name: str, keyword: str) -> str: + """ + Delete a keyword from a character's tags. + + Args: + char_name (str): The name of the character + keyword (str): The keyword to delete + + Returns: + str: Success/failure message + """ + try: + with sqlite3.connect(chat_DB_PATH) as conn: + cursor = conn.cursor() + + # First, check if the character exists + cursor.execute("SELECT tags FROM CharacterCards WHERE name = ?", (char_name,)) + result = cursor.fetchone() + + if not result: + return f"Character '{char_name}' not found." + + # Parse existing tags + current_tags = json.loads(result[0]) if result[0] else [] + + if keyword not in current_tags: + return f"Keyword '{keyword}' not found in character '{char_name}' tags." 
+ + # Remove the keyword + updated_tags = [tag for tag in current_tags if tag != keyword] + + # Update the character's tags + cursor.execute( + "UPDATE CharacterCards SET tags = ? WHERE name = ?", + (json.dumps(updated_tags), char_name) + ) + conn.commit() + + logging.info(f"Keyword '{keyword}' deleted from character '{char_name}'") + return f"Successfully deleted keyword '{keyword}' from character '{char_name}'." + + except Exception as e: + error_msg = f"Error deleting keyword: {str(e)}" + logging.error(error_msg) + return error_msg + + +def export_char_keywords_to_csv() -> Tuple[str, str]: + """ + Export all character keywords to a CSV file with associated metadata. + + Returns: + Tuple[str, str]: (status_message, file_path) + """ + import csv + from tempfile import NamedTemporaryFile + from datetime import datetime + + try: + # Create a temporary CSV file + temp_file = NamedTemporaryFile(mode='w+', delete=False, suffix='.csv', newline='') + + with sqlite3.connect(chat_DB_PATH) as conn: + cursor = conn.cursor() + + # Get all characters and their tags + cursor.execute(""" + SELECT + name, + tags, + (SELECT COUNT(*) FROM CharacterChats WHERE CharacterChats.character_id = CharacterCards.id) as chat_count + FROM CharacterCards + WHERE json_valid(tags) + ORDER BY name + """) + + results = cursor.fetchall() + + # Process the results to create rows for the CSV + csv_rows = [] + for name, tags_json, chat_count in results: + tags = json.loads(tags_json) if tags_json else [] + for tag in tags: + csv_rows.append([ + tag, # keyword + name, # character name + chat_count # number of chats + ]) + + # Write to CSV + writer = csv.writer(temp_file) + writer.writerow(['Keyword', 'Character Name', 'Number of Chats']) + writer.writerows(csv_rows) + + temp_file.close() + + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + status_msg = f"Successfully exported {len(csv_rows)} character keyword entries to CSV." 
+ logging.info(status_msg) + + return status_msg, temp_file.name + + except Exception as e: + error_msg = f"Error exporting keywords: {str(e)}" + logging.error(error_msg) + return error_msg, "" + +# +# End of Character chat keyword functions +###################################################### + + +# +# End of Character_Chat_DB.py +####################################################################################################################### diff --git a/App_Function_Libraries/DB/DB_Backups.py b/App_Function_Libraries/DB/DB_Backups.py new file mode 100644 index 0000000000000000000000000000000000000000..2ad4243dfe381e44969a22679ff28aa96d3a227f --- /dev/null +++ b/App_Function_Libraries/DB/DB_Backups.py @@ -0,0 +1,160 @@ +# Backup_Manager.py +# +# Imports: +import os +import shutil +import sqlite3 +from datetime import datetime +import logging +# +# Local Imports: +from App_Function_Libraries.DB.Character_Chat_DB import chat_DB_PATH +from App_Function_Libraries.DB.RAG_QA_Chat_DB import get_rag_qa_db_path +from App_Function_Libraries.Utils.Utils import get_project_relative_path +# +# End of Imports +####################################################################################################################### +# +# Functions: + +def init_backup_directory(backup_base_dir: str, db_name: str) -> str: + """Initialize backup directory for a specific database.""" + backup_dir = os.path.join(backup_base_dir, db_name) + os.makedirs(backup_dir, exist_ok=True) + return backup_dir + + +def create_backup(db_path: str, backup_dir: str, db_name: str) -> str: + """Create a full backup of the database.""" + try: + db_path = os.path.abspath(db_path) + backup_dir = os.path.abspath(backup_dir) + + logging.info(f"Creating backup:") + logging.info(f" DB Path: {db_path}") + logging.info(f" Backup Dir: {backup_dir}") + logging.info(f" DB Name: {db_name}") + + # Create subdirectory based on db_name + specific_backup_dir = os.path.join(backup_dir, db_name) + 
os.makedirs(specific_backup_dir, exist_ok=True) + + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + backup_file = os.path.join(specific_backup_dir, f"{db_name}_backup_{timestamp}.db") + logging.info(f" Full backup path: {backup_file}") + + # Create a backup using SQLite's backup API + with sqlite3.connect(db_path) as source, \ + sqlite3.connect(backup_file) as target: + source.backup(target) + + logging.info(f"Backup created successfully: {backup_file}") + return f"Backup created: {backup_file}" + except Exception as e: + error_msg = f"Failed to create backup: {str(e)}" + logging.error(error_msg) + return error_msg + + +def create_incremental_backup(db_path: str, backup_dir: str, db_name: str) -> str: + """Create an incremental backup using VACUUM INTO.""" + try: + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + backup_file = os.path.join(backup_dir, + f"{db_name}_incremental_{timestamp}.sqlib") + + with sqlite3.connect(db_path) as conn: + conn.execute(f"VACUUM INTO '{backup_file}'") + + logging.info(f"Incremental backup created: {backup_file}") + return f"Incremental backup created: {backup_file}" + except Exception as e: + error_msg = f"Failed to create incremental backup: {str(e)}" + logging.error(error_msg) + return error_msg + + +def list_backups(backup_dir: str) -> str: + """List all available backups.""" + try: + backups = [f for f in os.listdir(backup_dir) + if f.endswith(('.db', '.sqlib'))] + backups.sort(reverse=True) # Most recent first + return "\n".join(backups) if backups else "No backups found" + except Exception as e: + error_msg = f"Failed to list backups: {str(e)}" + logging.error(error_msg) + return error_msg + + +def restore_single_db_backup(db_path: str, backup_dir: str, db_name: str, backup_name: str) -> str: + """Restore database from a backup file.""" + try: + logging.info(f"Restoring backup: {backup_name}") + backup_path = os.path.join(backup_dir, backup_name) + if not os.path.exists(backup_path): + logging.error(f"Backup file 
not found: {backup_name}") + return f"Backup file not found: {backup_name}" + + # Create a timestamp for the current db + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + current_backup = os.path.join(backup_dir, + f"{db_name}_pre_restore_{timestamp}.db") + + # Backup current database before restore + logging.info(f"Creating backup of current database: {current_backup}") + shutil.copy2(db_path, current_backup) + + # Restore the backup + logging.info(f"Restoring database from {backup_name}") + shutil.copy2(backup_path, db_path) + + logging.info(f"Database restored from {backup_name}") + return f"Database restored from {backup_name}" + except Exception as e: + error_msg = f"Failed to restore backup: {str(e)}" + logging.error(error_msg) + return error_msg + + +def setup_backup_config(): + """Setup configuration for database backups.""" + backup_base_dir = get_project_relative_path('tldw_DB_Backups') + logging.info(f"Base backup directory: {os.path.abspath(backup_base_dir)}") + + # RAG Chat DB configuration + rag_db_path = get_rag_qa_db_path() + rag_backup_dir = os.path.join(backup_base_dir, 'rag_chat') + os.makedirs(rag_backup_dir, exist_ok=True) + logging.info(f"RAG backup directory: {os.path.abspath(rag_backup_dir)}") + + rag_db_config = { + 'db_path': rag_db_path, + 'backup_dir': rag_backup_dir, # Make sure we use the full path + 'db_name': 'rag_qa' + } + + # Character Chat DB configuration + char_backup_dir = os.path.join(backup_base_dir, 'character_chat') + os.makedirs(char_backup_dir, exist_ok=True) + logging.info(f"Character backup directory: {os.path.abspath(char_backup_dir)}") + + char_db_config = { + 'db_path': chat_DB_PATH, + 'backup_dir': char_backup_dir, # Make sure we use the full path + 'db_name': 'chatDB' + } + + # Media DB configuration (based on your logs) + media_backup_dir = os.path.join(backup_base_dir, 'media') + os.makedirs(media_backup_dir, exist_ok=True) + logging.info(f"Media backup directory: {os.path.abspath(media_backup_dir)}") + + 
media_db_config = { + 'db_path': os.path.join(os.path.dirname(chat_DB_PATH), 'media_summary.db'), + 'backup_dir': media_backup_dir, + 'db_name': 'media' + } + + return rag_db_config, char_db_config, media_db_config + diff --git a/App_Function_Libraries/DB/DB_Manager.py b/App_Function_Libraries/DB/DB_Manager.py index a11e4d9a3872d1f8ba36e707ca4d5914332c7675..72b736b50bd531aff2b14e400af11f5897da0e7e 100644 --- a/App_Function_Libraries/DB/DB_Manager.py +++ b/App_Function_Libraries/DB/DB_Manager.py @@ -13,11 +13,14 @@ from elasticsearch import Elasticsearch # # Import your existing SQLite functions from App_Function_Libraries.DB.SQLite_DB import DatabaseError +from App_Function_Libraries.DB.Prompts_DB import list_prompts as sqlite_list_prompts, \ + fetch_prompt_details as sqlite_fetch_prompt_details, add_prompt as sqlite_add_prompt, \ + search_prompts as sqlite_search_prompts, add_or_update_prompt as sqlite_add_or_update_prompt, \ + load_prompt_details as sqlite_load_prompt_details, insert_prompt_to_db as sqlite_insert_prompt_to_db, \ + delete_prompt as sqlite_delete_prompt from App_Function_Libraries.DB.SQLite_DB import ( update_media_content as sqlite_update_media_content, - list_prompts as sqlite_list_prompts, search_and_display as sqlite_search_and_display, - fetch_prompt_details as sqlite_fetch_prompt_details, keywords_browser_interface as sqlite_keywords_browser_interface, add_keyword as sqlite_add_keyword, delete_keyword as sqlite_delete_keyword, @@ -25,31 +28,17 @@ from App_Function_Libraries.DB.SQLite_DB import ( ingest_article_to_db as sqlite_ingest_article_to_db, add_media_to_database as sqlite_add_media_to_database, import_obsidian_note_to_db as sqlite_import_obsidian_note_to_db, - add_prompt as sqlite_add_prompt, - delete_chat_message as sqlite_delete_chat_message, - update_chat_message as sqlite_update_chat_message, - add_chat_message as sqlite_add_chat_message, - get_chat_messages as sqlite_get_chat_messages, - search_chat_conversations as 
sqlite_search_chat_conversations, - create_chat_conversation as sqlite_create_chat_conversation, - save_chat_history_to_database as sqlite_save_chat_history_to_database, view_database as sqlite_view_database, get_transcripts as sqlite_get_transcripts, get_trashed_items as sqlite_get_trashed_items, user_delete_item as sqlite_user_delete_item, empty_trash as sqlite_empty_trash, create_automated_backup as sqlite_create_automated_backup, - add_or_update_prompt as sqlite_add_or_update_prompt, - load_prompt_details as sqlite_load_prompt_details, - load_preset_prompts as sqlite_load_preset_prompts, - insert_prompt_to_db as sqlite_insert_prompt_to_db, - delete_prompt as sqlite_delete_prompt, search_and_display_items as sqlite_search_and_display_items, - get_conversation_name as sqlite_get_conversation_name, add_media_with_keywords as sqlite_add_media_with_keywords, check_media_and_whisper_model as sqlite_check_media_and_whisper_model, \ create_document_version as sqlite_create_document_version, - get_document_version as sqlite_get_document_version, sqlite_search_db, add_media_chunk as sqlite_add_media_chunk, + get_document_version as sqlite_get_document_version, search_media_db as sqlite_search_media_db, add_media_chunk as sqlite_add_media_chunk, sqlite_update_fts_for_media, get_unprocessed_media as sqlite_get_unprocessed_media, fetch_item_details as sqlite_fetch_item_details, \ search_media_database as sqlite_search_media_database, mark_as_trash as sqlite_mark_as_trash, \ get_media_transcripts as sqlite_get_media_transcripts, get_specific_transcript as sqlite_get_specific_transcript, \ @@ -60,23 +49,35 @@ from App_Function_Libraries.DB.SQLite_DB import ( delete_specific_prompt as sqlite_delete_specific_prompt, fetch_keywords_for_media as sqlite_fetch_keywords_for_media, \ update_keywords_for_media as sqlite_update_keywords_for_media, check_media_exists as sqlite_check_media_exists, \ - search_prompts as sqlite_search_prompts, get_media_content as sqlite_get_media_content, 
\ - get_paginated_files as sqlite_get_paginated_files, get_media_title as sqlite_get_media_title, \ - get_all_content_from_database as sqlite_get_all_content_from_database, - get_next_media_id as sqlite_get_next_media_id, \ - batch_insert_chunks as sqlite_batch_insert_chunks, Database, save_workflow_chat_to_db as sqlite_save_workflow_chat_to_db, \ - get_workflow_chat as sqlite_get_workflow_chat, update_media_content_with_version as sqlite_update_media_content_with_version, \ + get_media_content as sqlite_get_media_content, get_paginated_files as sqlite_get_paginated_files, \ + get_media_title as sqlite_get_media_title, get_all_content_from_database as sqlite_get_all_content_from_database, \ + get_next_media_id as sqlite_get_next_media_id, batch_insert_chunks as sqlite_batch_insert_chunks, Database, \ + save_workflow_chat_to_db as sqlite_save_workflow_chat_to_db, get_workflow_chat as sqlite_get_workflow_chat, \ + update_media_content_with_version as sqlite_update_media_content_with_version, \ check_existing_media as sqlite_check_existing_media, get_all_document_versions as sqlite_get_all_document_versions, \ fetch_paginated_data as sqlite_fetch_paginated_data, get_latest_transcription as sqlite_get_latest_transcription, \ mark_media_as_processed as sqlite_mark_media_as_processed, ) +from App_Function_Libraries.DB.RAG_QA_Chat_DB import start_new_conversation as sqlite_start_new_conversation, \ + save_message as sqlite_save_message, load_chat_history as sqlite_load_chat_history, \ + get_all_conversations as sqlite_get_all_conversations, get_notes_by_keywords as sqlite_get_notes_by_keywords, \ + get_note_by_id as sqlite_get_note_by_id, update_note as sqlite_update_note, save_notes as sqlite_save_notes, \ + clear_keywords_from_note as sqlite_clear_keywords_from_note, add_keywords_to_note as sqlite_add_keywords_to_note, \ + add_keywords_to_conversation as sqlite_add_keywords_to_conversation, \ + get_keywords_for_note as sqlite_get_keywords_for_note, delete_note as 
sqlite_delete_note, \ + search_conversations_by_keywords as sqlite_search_conversations_by_keywords, \ + delete_conversation as sqlite_delete_conversation, get_conversation_title as sqlite_get_conversation_title, \ + update_conversation_title as sqlite_update_conversation_title, \ + fetch_all_conversations as sqlite_fetch_all_conversations, fetch_all_notes as sqlite_fetch_all_notes, \ + fetch_conversations_by_ids as sqlite_fetch_conversations_by_ids, fetch_notes_by_ids as sqlite_fetch_notes_by_ids, \ + delete_messages_in_conversation as sqlite_delete_messages_in_conversation, \ + get_conversation_text as sqlite_get_conversation_text, search_notes_titles as sqlite_search_notes_titles from App_Function_Libraries.DB.Character_Chat_DB import ( add_character_card as sqlite_add_character_card, get_character_cards as sqlite_get_character_cards, \ get_character_card_by_id as sqlite_get_character_card_by_id, update_character_card as sqlite_update_character_card, \ delete_character_card as sqlite_delete_character_card, add_character_chat as sqlite_add_character_chat, \ get_character_chats as sqlite_get_character_chats, get_character_chat_by_id as sqlite_get_character_chat_by_id, \ - update_character_chat as sqlite_update_character_chat, delete_character_chat as sqlite_delete_character_chat, \ - migrate_chat_to_media_db as sqlite_migrate_chat_to_media_db, + update_character_chat as sqlite_update_character_chat, delete_character_chat as sqlite_delete_character_chat ) # # Local Imports @@ -214,9 +215,9 @@ print(f"Database path: {db.db_path}") # # DB Search functions -def search_db(search_query: str, search_fields: List[str], keywords: str, page: int = 1, results_per_page: int = 10): +def search_media_db(search_query: str, search_fields: List[str], keywords: str, page: int = 1, results_per_page: int = 10): if db_type == 'sqlite': - return sqlite_search_db(search_query, search_fields, keywords, page, results_per_page) + return sqlite_search_media_db(search_query, search_fields, 
keywords, page, results_per_page) elif db_type == 'elasticsearch': # Implement Elasticsearch version when available raise NotImplementedError("Elasticsearch version of search_db not yet implemented") @@ -500,13 +501,6 @@ def load_prompt_details(*args, **kwargs): # Implement Elasticsearch version raise NotImplementedError("Elasticsearch version of add_media_with_keywords not yet implemented") -def load_preset_prompts(*args, **kwargs): - if db_type == 'sqlite': - return sqlite_load_preset_prompts() - elif db_type == 'elasticsearch': - # Implement Elasticsearch version - raise NotImplementedError("Elasticsearch version of add_media_with_keywords not yet implemented") - def insert_prompt_to_db(*args, **kwargs): if db_type == 'sqlite': return sqlite_insert_prompt_to_db(*args, **kwargs) @@ -539,7 +533,6 @@ def mark_as_trash(media_id: int) -> None: else: raise ValueError(f"Unsupported database type: {db_type}") - def get_latest_transcription(*args, **kwargs): if db_type == 'sqlite': return sqlite_get_latest_transcription(*args, **kwargs) @@ -721,62 +714,132 @@ def fetch_keywords_for_media(*args, **kwargs): # # Chat-related Functions -def delete_chat_message(*args, **kwargs): +def search_notes_titles(*args, **kwargs): + if db_type == 'sqlite': + return sqlite_search_notes_titles(*args, **kwargs) + elif db_type == 'elasticsearch': + # Implement Elasticsearch version + raise NotImplementedError("Elasticsearch version of add_media_with_keywords not yet implemented") + +def save_message(*args, **kwargs): + if db_type == 'sqlite': + return sqlite_save_message(*args, **kwargs) + elif db_type == 'elasticsearch': + # Implement Elasticsearch version + raise NotImplementedError("Elasticsearch version of add_media_with_keywords not yet implemented") + +def load_chat_history(*args, **kwargs): + if db_type == 'sqlite': + return sqlite_load_chat_history(*args, **kwargs) + elif db_type == 'elasticsearch': + # Implement Elasticsearch version + raise NotImplementedError("Elasticsearch 
version of add_media_with_keywords not yet implemented") + +def start_new_conversation(*args, **kwargs): + if db_type == 'sqlite': + return sqlite_start_new_conversation(*args, **kwargs) + elif db_type == 'elasticsearch': + # Implement Elasticsearch version + raise NotImplementedError("Elasticsearch version of add_media_with_keywords not yet implemented") + +def get_all_conversations(*args, **kwargs): + if db_type == 'sqlite': + return sqlite_get_all_conversations(*args, **kwargs) + elif db_type == 'elasticsearch': + # Implement Elasticsearch version + raise NotImplementedError("Elasticsearch version of add_media_with_keywords not yet implemented") + +def get_notes_by_keywords(*args, **kwargs): + if db_type == 'sqlite': + return sqlite_get_notes_by_keywords(*args, **kwargs) + elif db_type == 'elasticsearch': + # Implement Elasticsearch version + raise NotImplementedError("Elasticsearch version of add_media_with_keywords not yet implemented") + +def get_note_by_id(*args, **kwargs): + if db_type == 'sqlite': + return sqlite_get_note_by_id(*args, **kwargs) + elif db_type == 'elasticsearch': + # Implement Elasticsearch version + raise NotImplementedError("Elasticsearch version of add_media_with_keywords not yet implemented") + +def add_keywords_to_conversation(*args, **kwargs): + if db_type == 'sqlite': + return sqlite_add_keywords_to_conversation(*args, **kwargs) + elif db_type == 'elasticsearch': + # Implement Elasticsearch version + raise NotImplementedError("Elasticsearch version of add_media_with_keywords not yet implemented") + +def get_keywords_for_note(*args, **kwargs): if db_type == 'sqlite': - return sqlite_delete_chat_message(*args, **kwargs) + return sqlite_get_keywords_for_note(*args, **kwargs) elif db_type == 'elasticsearch': # Implement Elasticsearch version raise NotImplementedError("Elasticsearch version of add_media_with_keywords not yet implemented") -def update_chat_message(*args, **kwargs): +def delete_note(*args, **kwargs): if db_type == 'sqlite': 
- return sqlite_update_chat_message(*args, **kwargs) + return sqlite_delete_note(*args, **kwargs) elif db_type == 'elasticsearch': # Implement Elasticsearch version raise NotImplementedError("Elasticsearch version of add_media_with_keywords not yet implemented") -def add_chat_message(*args, **kwargs): +def search_conversations_by_keywords(*args, **kwargs): if db_type == 'sqlite': - return sqlite_add_chat_message(*args, **kwargs) + return sqlite_search_conversations_by_keywords(*args, **kwargs) elif db_type == 'elasticsearch': # Implement Elasticsearch version raise NotImplementedError("Elasticsearch version of add_media_with_keywords not yet implemented") -def get_chat_messages(*args, **kwargs): +def delete_conversation(*args, **kwargs): if db_type == 'sqlite': - return sqlite_get_chat_messages(*args, **kwargs) + return sqlite_delete_conversation(*args, **kwargs) elif db_type == 'elasticsearch': # Implement Elasticsearch version raise NotImplementedError("Elasticsearch version of add_media_with_keywords not yet implemented") -def search_chat_conversations(*args, **kwargs): +def get_conversation_title(*args, **kwargs): if db_type == 'sqlite': - return sqlite_search_chat_conversations(*args, **kwargs) + return sqlite_get_conversation_title(*args, **kwargs) elif db_type == 'elasticsearch': # Implement Elasticsearch version raise NotImplementedError("Elasticsearch version of add_media_with_keywords not yet implemented") -def create_chat_conversation(*args, **kwargs): +def update_conversation_title(*args, **kwargs): if db_type == 'sqlite': - return sqlite_create_chat_conversation(*args, **kwargs) + return sqlite_update_conversation_title(*args, **kwargs) elif db_type == 'elasticsearch': # Implement Elasticsearch version raise NotImplementedError("Elasticsearch version of add_media_with_keywords not yet implemented") -def save_chat_history_to_database(*args, **kwargs): +def fetch_all_conversations(*args, **kwargs): if db_type == 'sqlite': - return 
sqlite_save_chat_history_to_database(*args, **kwargs) + return sqlite_fetch_all_conversations() elif db_type == 'elasticsearch': # Implement Elasticsearch version raise NotImplementedError("Elasticsearch version of add_media_with_keywords not yet implemented") -def get_conversation_name(*args, **kwargs): +def fetch_all_notes(*args, **kwargs): if db_type == 'sqlite': - return sqlite_get_conversation_name(*args, **kwargs) + return sqlite_fetch_all_notes() elif db_type == 'elasticsearch': # Implement Elasticsearch version raise NotImplementedError("Elasticsearch version of add_media_with_keywords not yet implemented") +def delete_messages_in_conversation(*args, **kwargs): + if db_type == 'sqlite': + return sqlite_delete_messages_in_conversation(*args, **kwargs) + elif db_type == 'elasticsearch': + # Implement Elasticsearch version + raise NotImplementedError("Elasticsearch version of delete_messages_in_conversation not yet implemented") + +def get_conversation_text(*args, **kwargs): + if db_type == 'sqlite': + return sqlite_get_conversation_text(*args, **kwargs) + elif db_type == 'elasticsearch': + # Implement Elasticsearch version + raise NotImplementedError("Elasticsearch version of get_conversation_text not yet implemented") + # # End of Chat-related Functions ############################################################################################################ @@ -856,12 +919,54 @@ def delete_character_chat(*args, **kwargs): # Implement Elasticsearch version raise NotImplementedError("Elasticsearch version of delete_character_chat not yet implemented") -def migrate_chat_to_media_db(*args, **kwargs): +def update_note(*args, **kwargs): + if db_type == 'sqlite': + return sqlite_update_note(*args, **kwargs) + elif db_type == 'elasticsearch': + # Implement Elasticsearch version + raise NotImplementedError("Elasticsearch version of update_note not yet implemented") + +def save_notes(*args, **kwargs): + if db_type == 'sqlite': + return sqlite_save_notes(*args, 
**kwargs) + elif db_type == 'elasticsearch': + # Implement Elasticsearch version + raise NotImplementedError("Elasticsearch version of save_notes not yet implemented") + +def clear_keywords(*args, **kwargs): + if db_type == 'sqlite': + return sqlite_clear_keywords_from_note(*args, **kwargs) + elif db_type == 'elasticsearch': + # Implement Elasticsearch version + raise NotImplementedError("Elasticsearch version of clear_keywords not yet implemented") + +def clear_keywords_from_note(*args, **kwargs): + if db_type == 'sqlite': + return sqlite_clear_keywords_from_note(*args, **kwargs) + elif db_type == 'elasticsearch': + # Implement Elasticsearch version + raise NotImplementedError("Elasticsearch version of clear_keywords_from_note not yet implemented") + +def add_keywords_to_note(*args, **kwargs): + if db_type == 'sqlite': + return sqlite_add_keywords_to_note(*args, **kwargs) + elif db_type == 'elasticsearch': + # Implement Elasticsearch version + raise NotImplementedError("Elasticsearch version of add_keywords_to_note not yet implemented") + +def fetch_conversations_by_ids(*args, **kwargs): + if db_type == 'sqlite': + return sqlite_fetch_conversations_by_ids(*args, **kwargs) + elif db_type == 'elasticsearch': + # Implement Elasticsearch version + raise NotImplementedError("Elasticsearch version of fetch_conversations_by_ids not yet implemented") + +def fetch_notes_by_ids(*args, **kwargs): if db_type == 'sqlite': - return sqlite_migrate_chat_to_media_db(*args, **kwargs) + return sqlite_fetch_notes_by_ids(*args, **kwargs) elif db_type == 'elasticsearch': # Implement Elasticsearch version - raise NotImplementedError("Elasticsearch version of migrate_chat_to_media_db not yet implemented") + raise NotImplementedError("Elasticsearch version of fetch_notes_by_ids not yet implemented") # # End of Character Chat-related Functions diff --git a/App_Function_Libraries/DB/Prompts_DB.py b/App_Function_Libraries/DB/Prompts_DB.py new file mode 100644 index 
0000000000000000000000000000000000000000..9fe9328b0bee161ec158c79266b1a26c8b84b07b --- /dev/null +++ b/App_Function_Libraries/DB/Prompts_DB.py @@ -0,0 +1,626 @@ +# Prompts_DB.py +# Description: Functions to manage the prompts database. +# +# Imports +import sqlite3 +import logging +# +# External Imports +import re +from typing import Tuple +# +# Local Imports +from App_Function_Libraries.Utils.Utils import get_database_path +# +####################################################################################################################### +# +# Functions to manage prompts DB + +def create_prompts_db(): + logging.debug("create_prompts_db: Creating prompts database.") + with sqlite3.connect(get_database_path('prompts.db')) as conn: + cursor = conn.cursor() + cursor.executescript(''' + CREATE TABLE IF NOT EXISTS Prompts ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + name TEXT NOT NULL UNIQUE, + author TEXT, + details TEXT, + system TEXT, + user TEXT + ); + CREATE TABLE IF NOT EXISTS Keywords ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + keyword TEXT NOT NULL UNIQUE COLLATE NOCASE + ); + CREATE TABLE IF NOT EXISTS PromptKeywords ( + prompt_id INTEGER, + keyword_id INTEGER, + FOREIGN KEY (prompt_id) REFERENCES Prompts (id), + FOREIGN KEY (keyword_id) REFERENCES Keywords (id), + PRIMARY KEY (prompt_id, keyword_id) + ); + CREATE INDEX IF NOT EXISTS idx_keywords_keyword ON Keywords(keyword); + CREATE INDEX IF NOT EXISTS idx_promptkeywords_prompt_id ON PromptKeywords(prompt_id); + CREATE INDEX IF NOT EXISTS idx_promptkeywords_keyword_id ON PromptKeywords(keyword_id); + ''') + +# FIXME - dirty hack that should be removed later... 
+# Migration function to add the 'author' column to the Prompts table +def add_author_column_to_prompts(): + with sqlite3.connect(get_database_path('prompts.db')) as conn: + cursor = conn.cursor() + # Check if 'author' column already exists + cursor.execute("PRAGMA table_info(Prompts)") + columns = [col[1] for col in cursor.fetchall()] + + if 'author' not in columns: + # Add the 'author' column + cursor.execute('ALTER TABLE Prompts ADD COLUMN author TEXT') + print("Author column added to Prompts table.") + else: + print("Author column already exists in Prompts table.") + +add_author_column_to_prompts() + +def normalize_keyword(keyword): + return re.sub(r'\s+', ' ', keyword.strip().lower()) + + +# FIXME - update calls to this function to use the new args +def add_prompt(name, author, details, system=None, user=None, keywords=None): + logging.debug(f"add_prompt: Adding prompt with name: {name}, author: {author}, system: {system}, user: {user}, keywords: {keywords}") + if not name: + logging.error("add_prompt: A name is required.") + return "A name is required." + + try: + with sqlite3.connect(get_database_path('prompts.db')) as conn: + cursor = conn.cursor() + cursor.execute(''' + INSERT INTO Prompts (name, author, details, system, user) + VALUES (?, ?, ?, ?, ?) + ''', (name, author, details, system, user)) + prompt_id = cursor.lastrowid + + if keywords: + normalized_keywords = [normalize_keyword(k) for k in keywords if k.strip()] + for keyword in set(normalized_keywords): # Use set to remove duplicates + cursor.execute(''' + INSERT OR IGNORE INTO Keywords (keyword) VALUES (?) + ''', (keyword,)) + cursor.execute('SELECT id FROM Keywords WHERE keyword = ?', (keyword,)) + keyword_id = cursor.fetchone()[0] + cursor.execute(''' + INSERT OR IGNORE INTO PromptKeywords (prompt_id, keyword_id) VALUES (?, ?) + ''', (prompt_id, keyword_id)) + return "Prompt added successfully." + except sqlite3.IntegrityError: + return "Prompt with this name already exists." 
+ except sqlite3.Error as e: + return f"Database error: {e}" + + +def fetch_prompt_details(name): + logging.debug(f"fetch_prompt_details: Fetching details for prompt: {name}") + with sqlite3.connect(get_database_path('prompts.db')) as conn: + cursor = conn.cursor() + cursor.execute(''' + SELECT p.name, p.author, p.details, p.system, p.user, GROUP_CONCAT(k.keyword, ', ') as keywords + FROM Prompts p + LEFT JOIN PromptKeywords pk ON p.id = pk.prompt_id + LEFT JOIN Keywords k ON pk.keyword_id = k.id + WHERE p.name = ? + GROUP BY p.id + ''', (name,)) + return cursor.fetchone() + + +def list_prompts(page=1, per_page=10): + logging.debug(f"list_prompts: Listing prompts for page {page} with {per_page} prompts per page.") + offset = (page - 1) * per_page + with sqlite3.connect(get_database_path('prompts.db')) as conn: + cursor = conn.cursor() + cursor.execute('SELECT name FROM Prompts LIMIT ? OFFSET ?', (per_page, offset)) + prompts = [row[0] for row in cursor.fetchall()] + + # Get total count of prompts + cursor.execute('SELECT COUNT(*) FROM Prompts') + total_count = cursor.fetchone()[0] + + total_pages = (total_count + per_page - 1) // per_page + return prompts, total_pages, page + + +def insert_prompt_to_db(title, author, description, system_prompt, user_prompt, keywords=None): + return add_prompt(title, author, description, system_prompt, user_prompt, keywords) + + +def get_prompt_db_connection(): + prompt_db_path = get_database_path('prompts.db') + return sqlite3.connect(prompt_db_path) + + +def search_prompts(query): + logging.debug(f"search_prompts: Searching prompts with query: {query}") + try: + with get_prompt_db_connection() as conn: + cursor = conn.cursor() + cursor.execute(""" + SELECT p.name, p.details, p.system, p.user, GROUP_CONCAT(k.keyword, ', ') as keywords + FROM Prompts p + LEFT JOIN PromptKeywords pk ON p.id = pk.prompt_id + LEFT JOIN Keywords k ON pk.keyword_id = k.id + WHERE p.name LIKE ? OR p.details LIKE ? OR p.system LIKE ? OR p.user LIKE ? 
OR k.keyword LIKE ? + GROUP BY p.id + ORDER BY p.name + """, (f'%{query}%', f'%{query}%', f'%{query}%', f'%{query}%', f'%{query}%')) + return cursor.fetchall() + except sqlite3.Error as e: + logging.error(f"Error searching prompts: {e}") + return [] + + +def search_prompts_by_keyword(keyword, page=1, per_page=10): + logging.debug(f"search_prompts_by_keyword: Searching prompts by keyword: {keyword}") + normalized_keyword = normalize_keyword(keyword) + offset = (page - 1) * per_page + with sqlite3.connect(get_database_path('prompts.db')) as conn: + cursor = conn.cursor() + cursor.execute(''' + SELECT DISTINCT p.name + FROM Prompts p + JOIN PromptKeywords pk ON p.id = pk.prompt_id + JOIN Keywords k ON pk.keyword_id = k.id + WHERE k.keyword LIKE ? + LIMIT ? OFFSET ? + ''', ('%' + normalized_keyword + '%', per_page, offset)) + prompts = [row[0] for row in cursor.fetchall()] + + # Get total count of matching prompts + cursor.execute(''' + SELECT COUNT(DISTINCT p.id) + FROM Prompts p + JOIN PromptKeywords pk ON p.id = pk.prompt_id + JOIN Keywords k ON pk.keyword_id = k.id + WHERE k.keyword LIKE ? + ''', ('%' + normalized_keyword + '%',)) + total_count = cursor.fetchone()[0] + + total_pages = (total_count + per_page - 1) // per_page + return prompts, total_pages, page + + +def update_prompt_keywords(prompt_name, new_keywords): + logging.debug(f"update_prompt_keywords: Updating keywords for prompt: {prompt_name}") + try: + with sqlite3.connect(get_database_path('prompts.db')) as conn: + cursor = conn.cursor() + + cursor.execute('SELECT id FROM Prompts WHERE name = ?', (prompt_name,)) + prompt_id = cursor.fetchone() + if not prompt_id: + return "Prompt not found." 
+ prompt_id = prompt_id[0] + + cursor.execute('DELETE FROM PromptKeywords WHERE prompt_id = ?', (prompt_id,)) + + normalized_keywords = [normalize_keyword(k) for k in new_keywords if k.strip()] + for keyword in set(normalized_keywords): # Use set to remove duplicates + cursor.execute('INSERT OR IGNORE INTO Keywords (keyword) VALUES (?)', (keyword,)) + cursor.execute('SELECT id FROM Keywords WHERE keyword = ?', (keyword,)) + keyword_id = cursor.fetchone()[0] + cursor.execute('INSERT INTO PromptKeywords (prompt_id, keyword_id) VALUES (?, ?)', + (prompt_id, keyword_id)) + + # Remove unused keywords + cursor.execute(''' + DELETE FROM Keywords + WHERE id NOT IN (SELECT DISTINCT keyword_id FROM PromptKeywords) + ''') + return "Keywords updated successfully." + except sqlite3.Error as e: + return f"Database error: {e}" + + +def add_or_update_prompt(title, author, description, system_prompt, user_prompt, keywords=None): + logging.debug(f"add_or_update_prompt: Adding or updating prompt: {title}") + if not title: + return "Error: Title is required." 
+ + existing_prompt = fetch_prompt_details(title) + if existing_prompt: + # Update existing prompt + result = update_prompt_in_db(title, author, description, system_prompt, user_prompt) + if "successfully" in result: + # Update keywords if the prompt update was successful + keyword_result = update_prompt_keywords(title, keywords or []) + result += f" {keyword_result}" + else: + # Insert new prompt + result = insert_prompt_to_db(title, author, description, system_prompt, user_prompt, keywords) + + return result + + +def load_prompt_details(selected_prompt): + logging.debug(f"load_prompt_details: Loading prompt details for {selected_prompt}") + if selected_prompt: + details = fetch_prompt_details(selected_prompt) + if details: + return details[0], details[1], details[2], details[3], details[4], details[5] + return "", "", "", "", "", "" + + +def update_prompt_in_db(title, author, description, system_prompt, user_prompt): + logging.debug(f"update_prompt_in_db: Updating prompt: {title}") + try: + with sqlite3.connect(get_database_path('prompts.db')) as conn: + cursor = conn.cursor() + cursor.execute( + "UPDATE Prompts SET author = ?, details = ?, system = ?, user = ? WHERE name = ?", + (author, description, system_prompt, user_prompt, title) + ) + if cursor.rowcount == 0: + return "No prompt found with the given title." + return "Prompt updated successfully!" 
+ except sqlite3.Error as e: + return f"Error updating prompt: {e}" + + +def delete_prompt(prompt_id): + logging.debug(f"delete_prompt: Deleting prompt with ID: {prompt_id}") + try: + with sqlite3.connect(get_database_path('prompts.db')) as conn: + cursor = conn.cursor() + + # Delete associated keywords + cursor.execute("DELETE FROM PromptKeywords WHERE prompt_id = ?", (prompt_id,)) + + # Delete the prompt + cursor.execute("DELETE FROM Prompts WHERE id = ?", (prompt_id,)) + + if cursor.rowcount == 0: + return f"No prompt found with ID {prompt_id}" + else: + conn.commit() + return f"Prompt with ID {prompt_id} has been successfully deleted" + except sqlite3.Error as e: + return f"An error occurred: {e}" + + +def delete_prompt_keyword(keyword: str) -> str: + """ + Delete a keyword and its associations from the prompts database. + + Args: + keyword (str): The keyword to delete + + Returns: + str: Success/failure message + """ + logging.debug(f"delete_prompt_keyword: Deleting keyword: {keyword}") + try: + with sqlite3.connect(get_database_path('prompts.db')) as conn: + cursor = conn.cursor() + + # First normalize the keyword + normalized_keyword = normalize_keyword(keyword) + + # Get the keyword ID + cursor.execute("SELECT id FROM Keywords WHERE keyword = ?", (normalized_keyword,)) + result = cursor.fetchone() + + if not result: + return f"Keyword '{keyword}' not found." + + keyword_id = result[0] + + # Delete keyword associations from PromptKeywords + cursor.execute("DELETE FROM PromptKeywords WHERE keyword_id = ?", (keyword_id,)) + + # Delete the keyword itself + cursor.execute("DELETE FROM Keywords WHERE id = ?", (keyword_id,)) + + # Get the number of affected prompts + affected_prompts = cursor.rowcount + + conn.commit() + + logging.info(f"Keyword '{keyword}' deleted successfully") + return f"Successfully deleted keyword '{keyword}' and removed it from {affected_prompts} prompts." 
+ + except sqlite3.Error as e: + error_msg = f"Database error deleting keyword: {str(e)}" + logging.error(error_msg) + return error_msg + except Exception as e: + error_msg = f"Error deleting keyword: {str(e)}" + logging.error(error_msg) + return error_msg + + +def export_prompt_keywords_to_csv() -> Tuple[str, str]: + """ + Export all prompt keywords to a CSV file with associated metadata. + + Returns: + Tuple[str, str]: (status_message, file_path) + """ + import csv + import tempfile + import os + from datetime import datetime + + logging.debug("export_prompt_keywords_to_csv: Starting export") + try: + # Create a temporary file with a specific name in the system's temp directory + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + temp_dir = tempfile.gettempdir() + file_path = os.path.join(temp_dir, f'prompt_keywords_export_{timestamp}.csv') + + with sqlite3.connect(get_database_path('prompts.db')) as conn: + cursor = conn.cursor() + + # Get keywords with related prompt information + query = ''' + SELECT + k.keyword, + GROUP_CONCAT(p.name, ' | ') as prompt_names, + COUNT(DISTINCT p.id) as num_prompts, + GROUP_CONCAT(DISTINCT p.author, ' | ') as authors + FROM Keywords k + LEFT JOIN PromptKeywords pk ON k.id = pk.keyword_id + LEFT JOIN Prompts p ON pk.prompt_id = p.id + GROUP BY k.id, k.keyword + ORDER BY k.keyword + ''' + + cursor.execute(query) + results = cursor.fetchall() + + # Write to CSV + with open(file_path, 'w', newline='', encoding='utf-8') as csvfile: + writer = csv.writer(csvfile) + writer.writerow([ + 'Keyword', + 'Associated Prompts', + 'Number of Prompts', + 'Authors' + ]) + + for row in results: + writer.writerow([ + row[0], # keyword + row[1] if row[1] else '', # prompt_names (may be None) + row[2], # num_prompts + row[3] if row[3] else '' # authors (may be None) + ]) + + status_msg = f"Successfully exported {len(results)} prompt keywords to CSV." 
+ logging.info(status_msg) + + return status_msg, file_path + + except sqlite3.Error as e: + error_msg = f"Database error exporting keywords: {str(e)}" + logging.error(error_msg) + return error_msg, "None" + except Exception as e: + error_msg = f"Error exporting keywords: {str(e)}" + logging.error(error_msg) + return error_msg, "None" + + +def view_prompt_keywords() -> str: + """ + View all keywords currently in the prompts database. + + Returns: + str: Markdown formatted string of all keywords + """ + logging.debug("view_prompt_keywords: Retrieving all keywords") + try: + with sqlite3.connect(get_database_path('prompts.db')) as conn: + cursor = conn.cursor() + cursor.execute(""" + SELECT k.keyword, COUNT(DISTINCT pk.prompt_id) as prompt_count + FROM Keywords k + LEFT JOIN PromptKeywords pk ON k.id = pk.keyword_id + GROUP BY k.id, k.keyword + ORDER BY k.keyword + """) + + keywords = cursor.fetchall() + if keywords: + keyword_list = [f"- {k[0]} ({k[1]} prompts)" for k in keywords] + return "### Current Prompt Keywords:\n" + "\n".join(keyword_list) + return "No keywords found." + + except Exception as e: + error_msg = f"Error retrieving keywords: {str(e)}" + logging.error(error_msg) + return error_msg + + +def export_prompts( + export_format='csv', + filter_keywords=None, + include_system=True, + include_user=True, + include_details=True, + include_author=True, + include_keywords=True, + markdown_template=None +) -> Tuple[str, str]: + """ + Export prompts to CSV or Markdown with configurable options. 
+ + Args: + export_format (str): 'csv' or 'markdown' + filter_keywords (List[str], optional): Keywords to filter prompts by + include_system (bool): Include system prompts in export + include_user (bool): Include user prompts in export + include_details (bool): Include prompt details/descriptions + include_author (bool): Include author information + include_keywords (bool): Include associated keywords + markdown_template (str, optional): Template for markdown export + + Returns: + Tuple[str, str]: (status_message, file_path) + """ + import csv + import tempfile + import os + import zipfile + from datetime import datetime + + try: + # Get prompts data + with get_prompt_db_connection() as conn: + cursor = conn.cursor() + + # Build query based on included fields + select_fields = ['p.name'] + if include_author: + select_fields.append('p.author') + if include_details: + select_fields.append('p.details') + if include_system: + select_fields.append('p.system') + if include_user: + select_fields.append('p.user') + + query = f""" + SELECT DISTINCT {', '.join(select_fields)} + FROM Prompts p + """ + + # Add keyword filtering if specified + if filter_keywords: + placeholders = ','.join(['?' for _ in filter_keywords]) + query += f""" + JOIN PromptKeywords pk ON p.id = pk.prompt_id + JOIN Keywords k ON pk.keyword_id = k.id + WHERE k.keyword IN ({placeholders}) + """ + + cursor.execute(query, filter_keywords if filter_keywords else ()) + prompts = cursor.fetchall() + + # Get keywords for each prompt if needed + if include_keywords: + prompt_keywords = {} + for prompt in prompts: + cursor.execute(""" + SELECT k.keyword + FROM Keywords k + JOIN PromptKeywords pk ON k.id = pk.keyword_id + JOIN Prompts p ON pk.prompt_id = p.id + WHERE p.name = ? 
+ """, (prompt[0],)) + prompt_keywords[prompt[0]] = [row[0] for row in cursor.fetchall()] + + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + + if export_format == 'csv': + # Export as CSV + temp_file = os.path.join(tempfile.gettempdir(), f'prompts_export_{timestamp}.csv') + with open(temp_file, 'w', newline='', encoding='utf-8') as csvfile: + writer = csv.writer(csvfile) + + # Write header + header = ['Name'] + if include_author: + header.append('Author') + if include_details: + header.append('Details') + if include_system: + header.append('System Prompt') + if include_user: + header.append('User Prompt') + if include_keywords: + header.append('Keywords') + writer.writerow(header) + + # Write data + for prompt in prompts: + row = list(prompt) + if include_keywords: + row.append(', '.join(prompt_keywords.get(prompt[0], []))) + writer.writerow(row) + + return f"Successfully exported {len(prompts)} prompts to CSV.", temp_file + + else: + # Export as Markdown files in ZIP + temp_dir = tempfile.mkdtemp() + zip_path = os.path.join(tempfile.gettempdir(), f'prompts_export_{timestamp}.zip') + + # Define markdown templates + templates = { + "Basic Template": """# {title} +{author_section} +{details_section} +{system_section} +{user_section} +{keywords_section} +""", + "Detailed Template": """# {title} + +## Author +{author_section} + +## Description +{details_section} + +## System Prompt +{system_section} + +## User Prompt +{user_section} + +## Keywords +{keywords_section} +""" + } + + template = templates.get(markdown_template, markdown_template or templates["Basic Template"]) + + with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zipf: + for prompt in prompts: + # Create markdown content + md_content = template.format( + title=prompt[0], + author_section=f"Author: {prompt[1]}" if include_author else "", + details_section=prompt[2] if include_details else "", + system_section=prompt[3] if include_system else "", + user_section=prompt[4] if include_user 
else "", + keywords_section=', '.join(prompt_keywords.get(prompt[0], [])) if include_keywords else "" + ) + + # Create safe filename + safe_filename = re.sub(r'[^\w\-_\. ]', '_', prompt[0]) + md_path = os.path.join(temp_dir, f"{safe_filename}.md") + + # Write markdown file + with open(md_path, 'w', encoding='utf-8') as f: + f.write(md_content) + + # Add to ZIP + zipf.write(md_path, os.path.basename(md_path)) + + return f"Successfully exported {len(prompts)} prompts to Markdown files.", zip_path + + except Exception as e: + error_msg = f"Error exporting prompts: {str(e)}" + logging.error(error_msg) + return error_msg, "None" + + +create_prompts_db() + +# +# End of Propmts_DB.py +####################################################################################################################### + diff --git a/App_Function_Libraries/DB/RAG_QA_Chat_DB.py b/App_Function_Libraries/DB/RAG_QA_Chat_DB.py index 6622ac5980bea0731894c257640f442052eb66b3..57f44c91f5ba7e6f8d9f69b3c2ec5d12173daf8b 100644 --- a/App_Function_Libraries/DB/RAG_QA_Chat_DB.py +++ b/App_Function_Libraries/DB/RAG_QA_Chat_DB.py @@ -4,39 +4,37 @@ # Imports import configparser import logging +import os import re import sqlite3 import uuid from contextlib import contextmanager from datetime import datetime - -from App_Function_Libraries.Utils.Utils import get_project_relative_path, get_database_path - +from pathlib import Path +from typing import List, Dict, Any, Tuple, Optional # # External Imports # (No external imports) # # Local Imports -# (No additional local imports) +from App_Function_Libraries.Utils.Utils import get_project_relative_path, get_project_root + # ######################################################################################################################## # # Functions: -# Construct the path to the config file -config_path = get_project_relative_path('Config_Files/config.txt') - -# Read the config file -config = configparser.ConfigParser() -config.read(config_path) - -# Get 
the SQLite path from the config, or use the default if not specified -if config.has_section('Database') and config.has_option('Database', 'rag_qa_db_path'): - rag_qa_db_path = config.get('Database', 'rag_qa_db_path') -else: - rag_qa_db_path = get_database_path('RAG_QA_Chat.db') - -print(f"RAG QA Chat Database path: {rag_qa_db_path}") +def get_rag_qa_db_path(): + config_path = os.path.join(get_project_root(), 'Config_Files', 'config.txt') + config = configparser.ConfigParser() + config.read(config_path) + if config.has_section('Database') and config.has_option('Database', 'rag_qa_db_path'): + rag_qa_db_path = config.get('Database', 'rag_qa_db_path') + if not os.path.isabs(rag_qa_db_path): + rag_qa_db_path = get_project_relative_path(rag_qa_db_path) + return rag_qa_db_path + else: + raise ValueError("Database path not found in config file") # Set up logging logging.basicConfig(level=logging.INFO) @@ -58,7 +56,9 @@ CREATE TABLE IF NOT EXISTS conversation_metadata ( conversation_id TEXT PRIMARY KEY, created_at DATETIME NOT NULL, last_updated DATETIME NOT NULL, - title TEXT NOT NULL + title TEXT NOT NULL, + media_id INTEGER, + rating INTEGER CHECK(rating BETWEEN 1 AND 3) ); -- Table for storing keywords @@ -122,19 +122,137 @@ CREATE INDEX IF NOT EXISTS idx_rag_qa_keyword_collections_parent_id ON rag_qa_ke CREATE INDEX IF NOT EXISTS idx_rag_qa_collection_keywords_collection_id ON rag_qa_collection_keywords(collection_id); CREATE INDEX IF NOT EXISTS idx_rag_qa_collection_keywords_keyword_id ON rag_qa_collection_keywords(keyword_id); --- Full-text search virtual table for chat content -CREATE VIRTUAL TABLE IF NOT EXISTS rag_qa_chats_fts USING fts5(conversation_id, timestamp, role, content); +-- Full-text search virtual tables +CREATE VIRTUAL TABLE IF NOT EXISTS rag_qa_chats_fts USING fts5( + content, + content='rag_qa_chats', + content_rowid='id' +); + +-- FTS table for conversation metadata +CREATE VIRTUAL TABLE IF NOT EXISTS conversation_metadata_fts USING fts5( + title, 
+ content='conversation_metadata', + content_rowid='rowid' +); + +-- FTS table for keywords +CREATE VIRTUAL TABLE IF NOT EXISTS rag_qa_keywords_fts USING fts5( + keyword, + content='rag_qa_keywords', + content_rowid='id' +); + +-- FTS table for keyword collections +CREATE VIRTUAL TABLE IF NOT EXISTS rag_qa_keyword_collections_fts USING fts5( + name, + content='rag_qa_keyword_collections', + content_rowid='id' +); + +-- FTS table for notes +CREATE VIRTUAL TABLE IF NOT EXISTS rag_qa_notes_fts USING fts5( + title, + content, + content='rag_qa_notes', + content_rowid='id' +); +-- FTS table for notes (modified to include both title and content) +CREATE VIRTUAL TABLE IF NOT EXISTS rag_qa_notes_fts USING fts5( + title, + content, + content='rag_qa_notes', + content_rowid='id' +); --- Trigger to keep the FTS table up to date +-- Triggers for maintaining FTS indexes +-- Triggers for rag_qa_chats CREATE TRIGGER IF NOT EXISTS rag_qa_chats_ai AFTER INSERT ON rag_qa_chats BEGIN - INSERT INTO rag_qa_chats_fts(conversation_id, timestamp, role, content) VALUES (new.conversation_id, new.timestamp, new.role, new.content); + INSERT INTO rag_qa_chats_fts(rowid, content) + VALUES (new.id, new.content); +END; + +CREATE TRIGGER IF NOT EXISTS rag_qa_chats_au AFTER UPDATE ON rag_qa_chats BEGIN + UPDATE rag_qa_chats_fts + SET content = new.content + WHERE rowid = old.id; +END; + +CREATE TRIGGER IF NOT EXISTS rag_qa_chats_ad AFTER DELETE ON rag_qa_chats BEGIN + DELETE FROM rag_qa_chats_fts WHERE rowid = old.id; +END; + +-- Triggers for conversation_metadata +CREATE TRIGGER IF NOT EXISTS conversation_metadata_ai AFTER INSERT ON conversation_metadata BEGIN + INSERT INTO conversation_metadata_fts(rowid, title) + VALUES (new.rowid, new.title); +END; + +CREATE TRIGGER IF NOT EXISTS conversation_metadata_au AFTER UPDATE ON conversation_metadata BEGIN + UPDATE conversation_metadata_fts + SET title = new.title + WHERE rowid = old.rowid; +END; + +CREATE TRIGGER IF NOT EXISTS conversation_metadata_ad 
AFTER DELETE ON conversation_metadata BEGIN + DELETE FROM conversation_metadata_fts WHERE rowid = old.rowid; +END; + +-- Triggers for rag_qa_keywords +CREATE TRIGGER IF NOT EXISTS rag_qa_keywords_ai AFTER INSERT ON rag_qa_keywords BEGIN + INSERT INTO rag_qa_keywords_fts(rowid, keyword) + VALUES (new.id, new.keyword); +END; + +CREATE TRIGGER IF NOT EXISTS rag_qa_keywords_au AFTER UPDATE ON rag_qa_keywords BEGIN + UPDATE rag_qa_keywords_fts + SET keyword = new.keyword + WHERE rowid = old.id; +END; + +CREATE TRIGGER IF NOT EXISTS rag_qa_keywords_ad AFTER DELETE ON rag_qa_keywords BEGIN + DELETE FROM rag_qa_keywords_fts WHERE rowid = old.id; +END; + +-- Triggers for rag_qa_keyword_collections +CREATE TRIGGER IF NOT EXISTS rag_qa_keyword_collections_ai AFTER INSERT ON rag_qa_keyword_collections BEGIN + INSERT INTO rag_qa_keyword_collections_fts(rowid, name) + VALUES (new.id, new.name); +END; + +CREATE TRIGGER IF NOT EXISTS rag_qa_keyword_collections_au AFTER UPDATE ON rag_qa_keyword_collections BEGIN + UPDATE rag_qa_keyword_collections_fts + SET name = new.name + WHERE rowid = old.id; +END; + +CREATE TRIGGER IF NOT EXISTS rag_qa_keyword_collections_ad AFTER DELETE ON rag_qa_keyword_collections BEGIN + DELETE FROM rag_qa_keyword_collections_fts WHERE rowid = old.id; +END; + +-- Triggers for rag_qa_notes +CREATE TRIGGER IF NOT EXISTS rag_qa_notes_ai AFTER INSERT ON rag_qa_notes BEGIN + INSERT INTO rag_qa_notes_fts(rowid, title, content) + VALUES (new.id, new.title, new.content); +END; + +CREATE TRIGGER IF NOT EXISTS rag_qa_notes_au AFTER UPDATE ON rag_qa_notes BEGIN + UPDATE rag_qa_notes_fts + SET title = new.title, + content = new.content + WHERE rowid = old.id; +END; + +CREATE TRIGGER IF NOT EXISTS rag_qa_notes_ad AFTER DELETE ON rag_qa_notes BEGIN + DELETE FROM rag_qa_notes_fts WHERE rowid = old.id; END; ''' # Database connection management @contextmanager def get_db_connection(): - conn = sqlite3.connect(rag_qa_db_path) + db_path = get_rag_qa_db_path() + conn = 
sqlite3.connect(db_path) try: yield conn finally: @@ -168,10 +286,43 @@ def execute_query(query, params=None, conn=None): conn.commit() return cursor.fetchall() + def create_tables(): + """Create database tables and initialize FTS indexes.""" with get_db_connection() as conn: - conn.executescript(SCHEMA_SQL) - logger.info("All RAG QA Chat tables created successfully") + cursor = conn.cursor() + # Execute the SCHEMA_SQL to create tables and triggers + cursor.executescript(SCHEMA_SQL) + + # Check and populate all FTS tables + fts_tables = [ + ('rag_qa_notes_fts', 'rag_qa_notes', ['title', 'content']), + ('rag_qa_chats_fts', 'rag_qa_chats', ['content']), + ('conversation_metadata_fts', 'conversation_metadata', ['title']), + ('rag_qa_keywords_fts', 'rag_qa_keywords', ['keyword']), + ('rag_qa_keyword_collections_fts', 'rag_qa_keyword_collections', ['name']) + ] + + for fts_table, source_table, columns in fts_tables: + # Check if FTS table needs population + cursor.execute(f"SELECT COUNT(*) FROM {fts_table}") + fts_count = cursor.fetchone()[0] + cursor.execute(f"SELECT COUNT(*) FROM {source_table}") + source_count = cursor.fetchone()[0] + + if fts_count != source_count: + # Repopulate FTS table + logger.info(f"Repopulating {fts_table}") + cursor.execute(f"DELETE FROM {fts_table}") + columns_str = ', '.join(columns) + source_columns = ', '.join([f"id" if source_table != 'conversation_metadata' else "rowid"] + columns) + cursor.execute(f""" + INSERT INTO {fts_table}(rowid, {columns_str}) + SELECT {source_columns} FROM {source_table} + """) + + logger.info("All RAG QA Chat tables and triggers created successfully") + # Initialize the database create_tables() @@ -197,6 +348,7 @@ def validate_keyword(keyword): raise ValueError("Keyword contains invalid characters") return keyword.strip() + def validate_collection_name(name): if not isinstance(name, str): raise ValueError("Collection name must be a string") @@ -208,6 +360,7 @@ def validate_collection_name(name): raise 
ValueError("Collection name contains invalid characters") return name.strip() + # Core functions def add_keyword(keyword, conn=None): try: @@ -222,6 +375,7 @@ def add_keyword(keyword, conn=None): logger.error(f"Error adding keyword '{keyword}': {e}") raise + def create_keyword_collection(name, parent_id=None): try: validated_name = validate_collection_name(name) @@ -235,6 +389,7 @@ def create_keyword_collection(name, parent_id=None): logger.error(f"Error creating keyword collection '{name}': {e}") raise + def add_keyword_to_collection(collection_name, keyword): try: validated_collection_name = validate_collection_name(collection_name) @@ -259,6 +414,7 @@ def add_keyword_to_collection(collection_name, keyword): logger.error(f"Error adding keyword '{keyword}' to collection '{collection_name}': {e}") raise + def add_keywords_to_conversation(conversation_id, keywords): if not isinstance(keywords, (list, tuple)): raise ValueError("Keywords must be a list or tuple") @@ -282,6 +438,23 @@ def add_keywords_to_conversation(conversation_id, keywords): logger.error(f"Error adding keywords to conversation '{conversation_id}': {e}") raise + +def view_rag_keywords(): + try: + rag_db_path = get_rag_qa_db_path() + with sqlite3.connect(rag_db_path) as conn: + cursor = conn.cursor() + cursor.execute("SELECT keyword FROM rag_qa_keywords ORDER BY keyword") + keywords = cursor.fetchall() + if keywords: + keyword_list = [k[0] for k in keywords] + return "### Current RAG QA Keywords:\n" + "\n".join( + [f"- {k}" for k in keyword_list]) + return "No keywords found." 
+ except Exception as e: + return f"Error retrieving keywords: {str(e)}" + + def get_keywords_for_conversation(conversation_id): try: query = ''' @@ -298,6 +471,7 @@ def get_keywords_for_conversation(conversation_id): logger.error(f"Error getting keywords for conversation '{conversation_id}': {e}") raise + def get_keywords_for_collection(collection_name): try: query = ''' @@ -315,6 +489,116 @@ def get_keywords_for_collection(collection_name): logger.error(f"Error getting keywords for collection '{collection_name}': {e}") raise + +def delete_rag_keyword(keyword: str) -> str: + """ + Delete a keyword from the RAG QA database and all its associations. + + Args: + keyword (str): The keyword to delete + + Returns: + str: Success/failure message + """ + try: + # Validate the keyword + validated_keyword = validate_keyword(keyword) + + with transaction() as conn: + # First, get the keyword ID + cursor = conn.cursor() + cursor.execute("SELECT id FROM rag_qa_keywords WHERE keyword = ?", (validated_keyword,)) + result = cursor.fetchone() + + if not result: + return f"Keyword '{validated_keyword}' not found." + + keyword_id = result[0] + + # Delete from all associated tables + cursor.execute("DELETE FROM rag_qa_conversation_keywords WHERE keyword_id = ?", (keyword_id,)) + cursor.execute("DELETE FROM rag_qa_collection_keywords WHERE keyword_id = ?", (keyword_id,)) + cursor.execute("DELETE FROM rag_qa_note_keywords WHERE keyword_id = ?", (keyword_id,)) + + # Finally, delete the keyword itself + cursor.execute("DELETE FROM rag_qa_keywords WHERE id = ?", (keyword_id,)) + + logger.info(f"Keyword '{validated_keyword}' deleted successfully") + return f"Successfully deleted keyword '{validated_keyword}' and all its associations." 
+ + except ValueError as e: + error_msg = f"Invalid keyword: {str(e)}" + logger.error(error_msg) + return error_msg + except Exception as e: + error_msg = f"Error deleting keyword: {str(e)}" + logger.error(error_msg) + return error_msg + + +def export_rag_keywords_to_csv() -> Tuple[str, str]: + """ + Export all RAG QA keywords to a CSV file. + + Returns: + Tuple[str, str]: (status_message, file_path) + """ + import csv + from tempfile import NamedTemporaryFile + from datetime import datetime + + try: + # Create a temporary CSV file + temp_file = NamedTemporaryFile(mode='w+', delete=False, suffix='.csv', newline='') + + with transaction() as conn: + cursor = conn.cursor() + + # Get all keywords and their associations + query = """ + SELECT + k.keyword, + GROUP_CONCAT(DISTINCT c.name) as collections, + COUNT(DISTINCT ck.conversation_id) as num_conversations, + COUNT(DISTINCT nk.note_id) as num_notes + FROM rag_qa_keywords k + LEFT JOIN rag_qa_collection_keywords col_k ON k.id = col_k.keyword_id + LEFT JOIN rag_qa_keyword_collections c ON col_k.collection_id = c.id + LEFT JOIN rag_qa_conversation_keywords ck ON k.id = ck.keyword_id + LEFT JOIN rag_qa_note_keywords nk ON k.id = nk.keyword_id + GROUP BY k.id, k.keyword + ORDER BY k.keyword + """ + + cursor.execute(query) + results = cursor.fetchall() + + # Write to CSV + writer = csv.writer(temp_file) + writer.writerow(['Keyword', 'Collections', 'Number of Conversations', 'Number of Notes']) + + for row in results: + writer.writerow([ + row[0], # keyword + row[1] if row[1] else '', # collections (may be None) + row[2], # num_conversations + row[3] # num_notes + ]) + + temp_file.close() + + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + status_msg = f"Successfully exported {len(results)} keywords to CSV." 
+ logger.info(status_msg) + + return status_msg, temp_file.name + + except Exception as e: + error_msg = f"Error exporting keywords: {str(e)}" + logger.error(error_msg) + return error_msg, "" + + # # End of Keyword-related functions ################################################### @@ -339,6 +623,7 @@ def save_notes(conversation_id, title, content): logger.error(f"Error saving notes for conversation '{conversation_id}': {e}") raise + def update_note(note_id, title, content): try: query = "UPDATE rag_qa_notes SET title = ?, content = ?, timestamp = ? WHERE id = ?" @@ -349,6 +634,121 @@ def update_note(note_id, title, content): logger.error(f"Error updating note ID '{note_id}': {e}") raise + +def search_notes_titles(search_term: str, page: int = 1, results_per_page: int = 20, connection=None) -> Tuple[ + List[Tuple], int, int]: + """ + Search note titles using full-text search. Returns all notes if search_term is empty. + + Args: + search_term (str): The search term for note titles. If empty, returns all notes. + page (int, optional): Page number for pagination. Defaults to 1. + results_per_page (int, optional): Number of results per page. Defaults to 20. + connection (sqlite3.Connection, optional): Database connection. Uses new connection if not provided. 
+ + Returns: + Tuple[List[Tuple], int, int]: Tuple containing: + - List of tuples: (note_id, title, content, timestamp, conversation_id) + - Total number of pages + - Total count of matching records + + Raises: + ValueError: If page number is less than 1 + sqlite3.Error: If there's a database error + """ + if page < 1: + raise ValueError("Page number must be 1 or greater.") + + offset = (page - 1) * results_per_page + + def execute_search(conn): + cursor = conn.cursor() + + # Debug: Show table contents + cursor.execute("SELECT title FROM rag_qa_notes") + main_titles = cursor.fetchall() + logger.debug(f"Main table titles: {main_titles}") + + cursor.execute("SELECT title FROM rag_qa_notes_fts") + fts_titles = cursor.fetchall() + logger.debug(f"FTS table titles: {fts_titles}") + + if not search_term.strip(): + # Query for all notes + cursor.execute( + """ + SELECT COUNT(*) + FROM rag_qa_notes + """ + ) + total_count = cursor.fetchone()[0] + + cursor.execute( + """ + SELECT id, title, content, timestamp, conversation_id + FROM rag_qa_notes + ORDER BY timestamp DESC + LIMIT ? OFFSET ? + """, + (results_per_page, offset) + ) + results = cursor.fetchall() + else: + # Search query + search_term_clean = search_term.strip().lower() + + # Test direct FTS search + cursor.execute( + """ + SELECT COUNT(*) + FROM rag_qa_notes n + JOIN rag_qa_notes_fts fts ON n.id = fts.rowid + WHERE fts.title MATCH ? + """, + (search_term_clean,) + ) + total_count = cursor.fetchone()[0] + + cursor.execute( + """ + SELECT + n.id, + n.title, + n.content, + n.timestamp, + n.conversation_id + FROM rag_qa_notes n + JOIN rag_qa_notes_fts fts ON n.id = fts.rowid + WHERE fts.title MATCH ? + ORDER BY rank + LIMIT ? OFFSET ? 
+ """, + (search_term_clean, results_per_page, offset) + ) + results = cursor.fetchall() + + logger.debug(f"Search term: {search_term_clean}") + logger.debug(f"Results: {results}") + + total_pages = max(1, (total_count + results_per_page - 1) // results_per_page) + logger.info(f"Found {total_count} matching notes for search term '{search_term}'") + + return results, total_pages, total_count + + try: + if connection: + return execute_search(connection) + else: + with get_db_connection() as conn: + return execute_search(conn) + + except sqlite3.Error as e: + logger.error(f"Database error in search_notes_titles: {str(e)}") + logger.error(f"Search term: {search_term}") + raise sqlite3.Error(f"Error searching notes: {str(e)}") + + + def get_notes(conversation_id): """Retrieve notes for a given conversation.""" try: @@ -361,6 +761,7 @@ def get_notes(conversation_id): logger.error(f"Error getting notes for conversation '{conversation_id}': {e}") raise + def get_note_by_id(note_id): try: query = "SELECT id, title, content FROM rag_qa_notes WHERE id = ?" 
@@ -370,9 +771,21 @@ def get_note_by_id(note_id): logger.error(f"Error getting note by ID '{note_id}': {e}") raise + def get_notes_by_keywords(keywords, page=1, page_size=20): try: - placeholders = ','.join(['?'] * len(keywords)) + # Handle empty or invalid keywords + if not keywords or not isinstance(keywords, (list, tuple)) or len(keywords) == 0: + return [], 0, 0 + + # Convert all keywords to strings and strip them + clean_keywords = [str(k).strip() for k in keywords if k is not None and str(k).strip()] + + # If no valid keywords after cleaning, return empty result + if not clean_keywords: + return [], 0, 0 + + placeholders = ','.join(['?'] * len(clean_keywords)) query = f''' SELECT n.id, n.title, n.content, n.timestamp FROM rag_qa_notes n @@ -381,14 +794,15 @@ def get_notes_by_keywords(keywords, page=1, page_size=20): WHERE k.keyword IN ({placeholders}) ORDER BY n.timestamp DESC ''' - results, total_pages, total_count = get_paginated_results(query, tuple(keywords), page, page_size) - logger.info(f"Retrieved {len(results)} notes matching keywords: {', '.join(keywords)} (page {page} of {total_pages})") + results, total_pages, total_count = get_paginated_results(query, tuple(clean_keywords), page, page_size) + logger.info(f"Retrieved {len(results)} notes matching keywords: {', '.join(clean_keywords)} (page {page} of {total_pages})") notes = [(row[0], row[1], row[2], row[3]) for row in results] return notes, total_pages, total_count except Exception as e: logger.error(f"Error getting notes by keywords: {e}") raise + def get_notes_by_keyword_collection(collection_name, page=1, page_size=20): try: query = ''' @@ -501,9 +915,10 @@ def delete_note(note_id): # # Chat-related functions -def save_message(conversation_id, role, content): +def save_message(conversation_id, role, content, timestamp=None): try: - timestamp = datetime.now().isoformat() + if timestamp is None: + timestamp = datetime.now().isoformat() query = "INSERT INTO rag_qa_chats (conversation_id, 
timestamp, role, content) VALUES (?, ?, ?, ?)" execute_query(query, (conversation_id, timestamp, role, content)) @@ -516,29 +931,103 @@ def save_message(conversation_id, role, content): logger.error(f"Error saving message for conversation '{conversation_id}': {e}") raise -def start_new_conversation(title="Untitled Conversation"): + +def start_new_conversation(title="Untitled Conversation", media_id=None): try: conversation_id = str(uuid.uuid4()) - query = "INSERT INTO conversation_metadata (conversation_id, created_at, last_updated, title) VALUES (?, ?, ?, ?)" + query = """ + INSERT INTO conversation_metadata ( + conversation_id, created_at, last_updated, title, media_id, rating + ) VALUES (?, ?, ?, ?, ?, ?) + """ now = datetime.now().isoformat() - execute_query(query, (conversation_id, now, now, title)) - logger.info(f"New conversation '{conversation_id}' started with title '{title}'") + # Set initial rating to NULL + execute_query(query, (conversation_id, now, now, title, media_id, None)) + logger.info(f"New conversation '{conversation_id}' started with title '{title}' and media_id '{media_id}'") return conversation_id except Exception as e: logger.error(f"Error starting new conversation: {e}") raise + def get_all_conversations(page=1, page_size=20): try: - query = "SELECT conversation_id, title FROM conversation_metadata ORDER BY last_updated DESC" - results, total_pages, total_count = get_paginated_results(query, page=page, page_size=page_size) - conversations = [(row[0], row[1]) for row in results] - logger.info(f"Retrieved {len(conversations)} conversations (page {page} of {total_pages})") - return conversations, total_pages, total_count + query = """ + SELECT conversation_id, title, media_id, rating + FROM conversation_metadata + ORDER BY last_updated DESC + LIMIT ? OFFSET ? 
+ """ + + count_query = "SELECT COUNT(*) FROM conversation_metadata" + db_path = get_rag_qa_db_path() + with sqlite3.connect(db_path) as conn: + cursor = conn.cursor() + + # Get total count + cursor.execute(count_query) + total_count = cursor.fetchone()[0] + total_pages = (total_count + page_size - 1) // page_size + + # Get page of results + offset = (page - 1) * page_size + cursor.execute(query, (page_size, offset)) + results = cursor.fetchall() + + conversations = [{ + 'conversation_id': row[0], + 'title': row[1], + 'media_id': row[2], + 'rating': row[3] # Include rating + } for row in results] + return conversations, total_pages, total_count + except Exception as e: + logging.error(f"Error getting conversations: {e}") + raise + + +def get_all_notes(page=1, page_size=20): + try: + query = """ + SELECT n.id, n.conversation_id, n.title, n.content, n.timestamp, + cm.title as conversation_title, cm.media_id + FROM rag_qa_notes n + LEFT JOIN conversation_metadata cm ON n.conversation_id = cm.conversation_id + ORDER BY n.timestamp DESC + LIMIT ? OFFSET ? 
+ """ + + count_query = "SELECT COUNT(*) FROM rag_qa_notes" + db_path = get_rag_qa_db_path() + with sqlite3.connect(db_path) as conn: + cursor = conn.cursor() + + # Get total count + cursor.execute(count_query) + total_count = cursor.fetchone()[0] + total_pages = (total_count + page_size - 1) // page_size + + # Get page of results + offset = (page - 1) * page_size + cursor.execute(query, (page_size, offset)) + results = cursor.fetchall() + + notes = [{ + 'id': row[0], + 'conversation_id': row[1], + 'title': row[2], + 'content': row[3], + 'timestamp': row[4], + 'conversation_title': row[5], + 'media_id': row[6] + } for row in results] + + return notes, total_pages, total_count except Exception as e: - logger.error(f"Error getting conversations: {e}") + logging.error(f"Error getting notes: {e}") raise + # Pagination helper function def get_paginated_results(query, params=None, page=1, page_size=20): try: @@ -564,6 +1053,7 @@ def get_paginated_results(query, params=None, page=1, page_size=20): logger.error(f"Error retrieving paginated results: {e}") raise + def get_all_collections(page=1, page_size=20): try: query = "SELECT name FROM rag_qa_keyword_collections" @@ -575,24 +1065,79 @@ def get_all_collections(page=1, page_size=20): logger.error(f"Error getting collections: {e}") raise -def search_conversations_by_keywords(keywords, page=1, page_size=20): + +def search_conversations_by_keywords(keywords=None, title_query=None, content_query=None, page=1, page_size=20): try: - placeholders = ','.join(['?' 
for _ in keywords]) - query = f''' - SELECT DISTINCT cm.conversation_id, cm.title + # Base query starts with conversation metadata + query = """ + SELECT DISTINCT cm.conversation_id, cm.title, cm.last_updated FROM conversation_metadata cm - JOIN rag_qa_conversation_keywords ck ON cm.conversation_id = ck.conversation_id - JOIN rag_qa_keywords k ON ck.keyword_id = k.id - WHERE k.keyword IN ({placeholders}) - ''' - results, total_pages, total_count = get_paginated_results(query, tuple(keywords), page, page_size) - logger.info( - f"Found {total_count} conversations matching keywords: {', '.join(keywords)} (page {page} of {total_pages})") - return results, total_pages, total_count + WHERE 1=1 + """ + params = [] + + # Add content search if provided + if content_query and isinstance(content_query, str) and content_query.strip(): + query += """ + AND EXISTS ( + SELECT 1 FROM rag_qa_chats_fts + WHERE rag_qa_chats_fts.content MATCH ? + AND rag_qa_chats_fts.rowid IN ( + SELECT id FROM rag_qa_chats + WHERE conversation_id = cm.conversation_id + ) + ) + """ + params.append(content_query.strip()) + + # Add title search if provided + if title_query and isinstance(title_query, str) and title_query.strip(): + query += """ + AND EXISTS ( + SELECT 1 FROM conversation_metadata_fts + WHERE conversation_metadata_fts.title MATCH ? + AND conversation_metadata_fts.rowid = cm.rowid + ) + """ + params.append(title_query.strip()) + + # Add keyword search if provided + if keywords and isinstance(keywords, (list, tuple)) and len(keywords) > 0: + # Convert all keywords to strings and strip them + clean_keywords = [str(k).strip() for k in keywords if k is not None and str(k).strip()] + if clean_keywords: # Only add to query if we have valid keywords + placeholders = ','.join(['?' 
for _ in clean_keywords]) + query += f""" + AND EXISTS ( + SELECT 1 FROM rag_qa_conversation_keywords ck + JOIN rag_qa_keywords k ON ck.keyword_id = k.id + WHERE ck.conversation_id = cm.conversation_id + AND k.keyword IN ({placeholders}) + ) + """ + params.extend(clean_keywords) + + # Add ordering + query += " ORDER BY cm.last_updated DESC" + + results, total_pages, total_count = get_paginated_results(query, tuple(params), page, page_size) + + conversations = [ + { + 'conversation_id': row[0], + 'title': row[1], + 'last_updated': row[2] + } + for row in results + ] + + return conversations, total_pages, total_count + except Exception as e: - logger.error(f"Error searching conversations by keywords {keywords}: {e}") + logger.error(f"Error searching conversations: {e}") raise + def load_chat_history(conversation_id, page=1, page_size=50): try: query = "SELECT role, content FROM rag_qa_chats WHERE conversation_id = ? ORDER BY timestamp" @@ -604,6 +1149,7 @@ def load_chat_history(conversation_id, page=1, page_size=50): logger.error(f"Error loading chat history for conversation '{conversation_id}': {e}") raise + def update_conversation_title(conversation_id, new_title): """Update the title of a conversation.""" try: @@ -614,6 +1160,59 @@ def update_conversation_title(conversation_id, new_title): logger.error(f"Error updating conversation title: {e}") raise + +def delete_messages_in_conversation(conversation_id): + """Helper function to delete all messages in a conversation.""" + try: + execute_query("DELETE FROM rag_qa_chats WHERE conversation_id = ?", (conversation_id,)) + logging.info(f"Messages in conversation '{conversation_id}' deleted successfully.") + except Exception as e: + logging.error(f"Error deleting messages in conversation '{conversation_id}': {e}") + raise + + +def get_conversation_title(conversation_id): + """Helper function to get the conversation title.""" + query = "SELECT title FROM conversation_metadata WHERE conversation_id = ?" 
+ result = execute_query(query, (conversation_id,)) + if result: + return result[0][0] + else: + return "Untitled Conversation" + + +def get_conversation_text(conversation_id): + try: + query = """ + SELECT role, content + FROM rag_qa_chats + WHERE conversation_id = ? + ORDER BY timestamp ASC + """ + + messages = [] + # Use the connection as a context manager + db_path = get_rag_qa_db_path() + with sqlite3.connect(db_path) as conn: + cursor = conn.cursor() + cursor.execute(query, (conversation_id,)) + messages = cursor.fetchall() + + return "\n\n".join([f"{msg[0]}: {msg[1]}" for msg in messages]) + except Exception as e: + logger.error(f"Error getting conversation text: {e}") + raise + + +def get_conversation_details(conversation_id): + query = "SELECT title, media_id, rating FROM conversation_metadata WHERE conversation_id = ?" + result = execute_query(query, (conversation_id,)) + if result: + return {'title': result[0][0], 'media_id': result[0][1], 'rating': result[0][2]} + else: + return {'title': "Untitled Conversation", 'media_id': None, 'rating': None} + + def delete_conversation(conversation_id): """Delete a conversation and its associated messages and notes.""" try: @@ -633,11 +1232,203 @@ def delete_conversation(conversation_id): logger.error(f"Error deleting conversation '{conversation_id}': {e}") raise +def set_conversation_rating(conversation_id, rating): + """Set the rating for a conversation.""" + # Validate rating + if rating not in [1, 2, 3]: + raise ValueError('Rating must be an integer between 1 and 3.') + try: + query = "UPDATE conversation_metadata SET rating = ? WHERE conversation_id = ?" 
+ execute_query(query, (rating, conversation_id)) + logger.info(f"Rating for conversation '{conversation_id}' set to {rating}") + except Exception as e: + logger.error(f"Error setting rating for conversation '{conversation_id}': {e}") + raise + +def get_conversation_rating(conversation_id): + """Get the rating of a conversation.""" + try: + query = "SELECT rating FROM conversation_metadata WHERE conversation_id = ?" + result = execute_query(query, (conversation_id,)) + if result: + rating = result[0][0] + logger.info(f"Rating for conversation '{conversation_id}' is {rating}") + return rating + else: + logger.warning(f"Conversation '{conversation_id}' not found.") + return None + except Exception as e: + logger.error(f"Error getting rating for conversation '{conversation_id}': {e}") + raise + + +def get_conversation_name(conversation_id: str) -> str: + """ + Retrieves the title/name of a conversation from the conversation_metadata table. + + Args: + conversation_id (str): The unique identifier of the conversation + + Returns: + str: The title of the conversation if found, "Untitled Conversation" if not found + + Raises: + sqlite3.Error: If there's a database error + """ + try: + with get_db_connection() as conn: + cursor = conn.cursor() + cursor.execute( + "SELECT title FROM conversation_metadata WHERE conversation_id = ?", + (conversation_id,) + ) + result = cursor.fetchone() + + if result: + return result[0] + else: + logging.warning(f"No conversation found with ID: {conversation_id}") + return "Untitled Conversation" + + except sqlite3.Error as e: + logging.error(f"Database error retrieving conversation name for ID {conversation_id}: {e}") + raise + except Exception as e: + logging.error(f"Unexpected error retrieving conversation name for ID {conversation_id}: {e}") + raise + + +def search_rag_chat(query: str, fts_top_k: int = 10, relevant_media_ids: List[str] = None) -> List[Dict[str, Any]]: + """ + Perform a full-text search on the RAG Chat database. 
+ + Args: + query: Search query string. + fts_top_k: Maximum number of results to return. + relevant_media_ids: Optional list of media IDs to filter results. + + Returns: + List of search results with content and metadata. + """ + if not query.strip(): + return [] + + try: + db_path = get_rag_qa_db_path() + with sqlite3.connect(db_path) as conn: + cursor = conn.cursor() + # Perform the full-text search using the FTS virtual table + cursor.execute(""" + SELECT rag_qa_chats.id, rag_qa_chats.conversation_id, rag_qa_chats.role, rag_qa_chats.content + FROM rag_qa_chats_fts + JOIN rag_qa_chats ON rag_qa_chats_fts.rowid = rag_qa_chats.id + WHERE rag_qa_chats_fts MATCH ? + LIMIT ? + """, (query, fts_top_k)) + + rows = cursor.fetchall() + columns = [description[0] for description in cursor.description] + results = [dict(zip(columns, row)) for row in rows] + + # Filter by relevant_media_ids if provided + if relevant_media_ids is not None: + results = [ + r for r in results + if get_conversation_details(r['conversation_id']).get('media_id') in relevant_media_ids + ] + + # Format results + formatted_results = [ + { + "content": r['content'], + "metadata": { + "conversation_id": r['conversation_id'], + "role": r['role'], + "media_id": get_conversation_details(r['conversation_id']).get('media_id') + } + } + for r in results + ] + return formatted_results + + except Exception as e: + logging.error(f"Error in search_rag_chat: {e}") + return [] + + +def search_rag_notes(query: str, fts_top_k: int = 10, relevant_media_ids: List[str] = None) -> List[Dict[str, Any]]: + """ + Perform a full-text search on the RAG Notes database. + + Args: + query: Search query string. + fts_top_k: Maximum number of results to return. + relevant_media_ids: Optional list of media IDs to filter results. + + Returns: + List of search results with content and metadata. 
+ """ + if not query.strip(): + return [] + + try: + db_path = get_rag_qa_db_path() + with sqlite3.connect(db_path) as conn: + cursor = conn.cursor() + # Perform the full-text search using the FTS virtual table + cursor.execute(""" + SELECT rag_qa_notes.id, rag_qa_notes.title, rag_qa_notes.content, rag_qa_notes.conversation_id + FROM rag_qa_notes_fts + JOIN rag_qa_notes ON rag_qa_notes_fts.rowid = rag_qa_notes.id + WHERE rag_qa_notes_fts MATCH ? + LIMIT ? + """, (query, fts_top_k)) + + rows = cursor.fetchall() + columns = [description[0] for description in cursor.description] + results = [dict(zip(columns, row)) for row in rows] + + # Filter by relevant_media_ids if provided + if relevant_media_ids is not None: + results = [ + r for r in results + if get_conversation_details(r['conversation_id']).get('media_id') in relevant_media_ids + ] + + # Format results + formatted_results = [ + { + "content": r['content'], + "metadata": { + "note_id": r['id'], + "title": r['title'], + "conversation_id": r['conversation_id'], + "media_id": get_conversation_details(r['conversation_id']).get('media_id') + } + } + for r in results + ] + return formatted_results + + except Exception as e: + logging.error(f"Error in search_rag_notes: {e}") + return [] + # # End of Chat-related functions ################################################### +################################################### +# +# Import functions + + +# +# End of Import functions +################################################### + + ################################################### # # Functions to export DB data diff --git a/App_Function_Libraries/DB/SQLite_DB.py b/App_Function_Libraries/DB/SQLite_DB.py index 2c05cbb86041120cc0c277aebeb2e45e42eb9050..66adda99721d1ab4a1f6ad0350ce0340eed6ba8c 100644 --- a/App_Function_Libraries/DB/SQLite_DB.py +++ b/App_Function_Libraries/DB/SQLite_DB.py @@ -21,7 +21,7 @@ import configparser # 11. browse_items(search_query, search_type) # 12. 
fetch_item_details(media_id: int) # 13. add_media_version(media_id: int, prompt: str, summary: str) -# 14. search_db(search_query: str, search_fields: List[str], keywords: str, page: int = 1, results_per_page: int = 10) +# 14. search_media_db(search_query: str, search_fields: List[str], keywords: str, page: int = 1, results_per_page: int = 10) # 15. search_and_display(search_query, search_fields, keywords, page) # 16. display_details(index, results) # 17. get_details(index, dataframe) @@ -55,12 +55,14 @@ import re import shutil import sqlite3 import threading +import time import traceback from contextlib import contextmanager from datetime import datetime, timedelta from typing import List, Tuple, Dict, Any, Optional from urllib.parse import quote +from App_Function_Libraries.Metrics.metrics_logger import log_counter, log_histogram # Local Libraries from App_Function_Libraries.Utils.Utils import get_project_relative_path, get_database_path, \ get_database_dir @@ -342,27 +344,6 @@ def create_tables(db) -> None: ) ''', ''' - CREATE TABLE IF NOT EXISTS ChatConversations ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - media_id INTEGER, - media_name TEXT, - conversation_name TEXT, - created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, - updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, - FOREIGN KEY (media_id) REFERENCES Media(id) - ) - ''', - ''' - CREATE TABLE IF NOT EXISTS ChatMessages ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - conversation_id INTEGER, - sender TEXT, - message TEXT, - timestamp TIMESTAMP DEFAULT CURRENT_TIMESTAMP, - FOREIGN KEY (conversation_id) REFERENCES ChatConversations(id) - ) - ''', - ''' CREATE TABLE IF NOT EXISTS Transcripts ( id INTEGER PRIMARY KEY AUTOINCREMENT, media_id INTEGER, @@ -421,8 +402,6 @@ def create_tables(db) -> None: 'CREATE INDEX IF NOT EXISTS idx_mediakeywords_keyword_id ON MediaKeywords(keyword_id)', 'CREATE INDEX IF NOT EXISTS idx_media_version_media_id ON MediaVersion(media_id)', 'CREATE INDEX IF NOT EXISTS 
idx_mediamodifications_media_id ON MediaModifications(media_id)', - 'CREATE INDEX IF NOT EXISTS idx_chatconversations_media_id ON ChatConversations(media_id)', - 'CREATE INDEX IF NOT EXISTS idx_chatmessages_conversation_id ON ChatMessages(conversation_id)', 'CREATE INDEX IF NOT EXISTS idx_media_is_trash ON Media(is_trash)', 'CREATE INDEX IF NOT EXISTS idx_mediachunks_media_id ON MediaChunks(media_id)', 'CREATE INDEX IF NOT EXISTS idx_unvectorized_media_chunks_media_id ON UnvectorizedMediaChunks(media_id)', @@ -606,7 +585,10 @@ def mark_media_as_processed(database, media_id): # Function to add media with keywords def add_media_with_keywords(url, title, media_type, content, keywords, prompt, summary, transcription_model, author, ingestion_date): + log_counter("add_media_with_keywords_attempt") + start_time = time.time() logging.debug(f"Entering add_media_with_keywords: URL={url}, Title={title}") + # Set default values for missing fields if url is None: url = 'localhost' @@ -622,10 +604,17 @@ def add_media_with_keywords(url, title, media_type, content, keywords, prompt, s author = author or 'Unknown' ingestion_date = ingestion_date or datetime.now().strftime('%Y-%m-%d') - if media_type not in ['article', 'audio', 'document', 'mediawiki_article', 'mediawiki_dump', 'obsidian_note', 'podcast', 'text', 'video', 'unknown']: - raise InputError("Invalid media type. Allowed types: article, audio file, document, obsidian_note podcast, text, video, unknown.") + if media_type not in ['article', 'audio', 'book', 'document', 'mediawiki_article', 'mediawiki_dump', + 'obsidian_note', 'podcast', 'text', 'video', 'unknown']: + log_counter("add_media_with_keywords_error", labels={"error_type": "InvalidMediaType"}) + duration = time.time() - start_time + log_histogram("add_media_with_keywords_duration", duration) + raise InputError("Invalid media type. 
Allowed types: article, audio file, document, obsidian_note, podcast, text, video, unknown.") if ingestion_date and not is_valid_date(ingestion_date): + log_counter("add_media_with_keywords_error", labels={"error_type": "InvalidDateFormat"}) + duration = time.time() - start_time + log_histogram("add_media_with_keywords_duration", duration) raise InputError("Invalid ingestion date format. Use YYYY-MM-DD.") # Handle keywords as either string or list @@ -654,6 +643,7 @@ def add_media_with_keywords(url, title, media_type, content, keywords, prompt, s logging.debug(f"Existing media ID for {url}: {existing_media_id}") if existing_media_id: + # Update existing media media_id = existing_media_id logging.debug(f"Updating existing media with ID: {media_id}") cursor.execute(''' @@ -661,7 +651,9 @@ def add_media_with_keywords(url, title, media_type, content, keywords, prompt, s SET content = ?, transcription_model = ?, type = ?, author = ?, ingestion_date = ? WHERE id = ? ''', (content, transcription_model, media_type, author, ingestion_date, media_id)) + log_counter("add_media_with_keywords_update") else: + # Insert new media logging.debug("Inserting new media") cursor.execute(''' INSERT INTO Media (url, title, type, content, author, ingestion_date, transcription_model) @@ -669,6 +661,7 @@ def add_media_with_keywords(url, title, media_type, content, keywords, prompt, s ''', (url, title, media_type, content, author, ingestion_date, transcription_model)) media_id = cursor.lastrowid logging.debug(f"New media inserted with ID: {media_id}") + log_counter("add_media_with_keywords_insert") cursor.execute(''' INSERT INTO MediaModifications (media_id, prompt, summary, modification_date) @@ -698,13 +691,23 @@ def add_media_with_keywords(url, title, media_type, content, keywords, prompt, s conn.commit() logging.info(f"Media '{title}' successfully added/updated with ID: {media_id}") - return media_id, f"Media '{title}' added/updated successfully with keywords: {', '.join(keyword_list)}" 
+ duration = time.time() - start_time + log_histogram("add_media_with_keywords_duration", duration) + log_counter("add_media_with_keywords_success") + + return media_id, f"Media '{title}' added/updated successfully with keywords: {', '.join(keyword_list)}" except sqlite3.Error as e: logging.error(f"SQL Error in add_media_with_keywords: {e}") + duration = time.time() - start_time + log_histogram("add_media_with_keywords_duration", duration) + log_counter("add_media_with_keywords_error", labels={"error_type": "SQLiteError"}) raise DatabaseError(f"Error adding media with keywords: {e}") except Exception as e: logging.error(f"Unexpected Error in add_media_with_keywords: {e}") + duration = time.time() - start_time + log_histogram("add_media_with_keywords_duration", duration) + log_counter("add_media_with_keywords_error", labels={"error_type": type(e).__name__}) raise DatabaseError(f"Unexpected error: {e}") @@ -779,7 +782,13 @@ def ingest_article_to_db(url, title, author, content, keywords, summary, ingesti # Function to add a keyword def add_keyword(keyword: str) -> int: + log_counter("add_keyword_attempt") + start_time = time.time() + if not keyword.strip(): + log_counter("add_keyword_error", labels={"error_type": "EmptyKeyword"}) + duration = time.time() - start_time + log_histogram("add_keyword_duration", duration) raise DatabaseError("Keyword cannot be empty") keyword = keyword.strip().lower() @@ -801,18 +810,32 @@ def add_keyword(keyword: str) -> int: logging.info(f"Keyword '{keyword}' added or updated with ID: {keyword_id}") conn.commit() + + duration = time.time() - start_time + log_histogram("add_keyword_duration", duration) + log_counter("add_keyword_success") + return keyword_id except sqlite3.IntegrityError as e: logging.error(f"Integrity error adding keyword: {e}") + duration = time.time() - start_time + log_histogram("add_keyword_duration", duration) + log_counter("add_keyword_error", labels={"error_type": "IntegrityError"}) raise DatabaseError(f"Integrity 
error adding keyword: {e}") except sqlite3.Error as e: logging.error(f"Error adding keyword: {e}") + duration = time.time() - start_time + log_histogram("add_keyword_duration", duration) + log_counter("add_keyword_error", labels={"error_type": "SQLiteError"}) raise DatabaseError(f"Error adding keyword: {e}") # Function to delete a keyword def delete_keyword(keyword: str) -> str: + log_counter("delete_keyword_attempt") + start_time = time.time() + keyword = keyword.strip().lower() with db.get_connection() as conn: cursor = conn.cursor() @@ -823,10 +846,23 @@ def delete_keyword(keyword: str) -> str: cursor.execute('DELETE FROM Keywords WHERE keyword = ?', (keyword,)) cursor.execute('DELETE FROM keyword_fts WHERE rowid = ?', (keyword_id[0],)) conn.commit() + + duration = time.time() - start_time + log_histogram("delete_keyword_duration", duration) + log_counter("delete_keyword_success") + return f"Keyword '{keyword}' deleted successfully." else: + duration = time.time() - start_time + log_histogram("delete_keyword_duration", duration) + log_counter("delete_keyword_not_found") + return f"Keyword '{keyword}' not found." 
except sqlite3.Error as e: + duration = time.time() - start_time + log_histogram("delete_keyword_duration", duration) + log_counter("delete_keyword_error", labels={"error_type": type(e).__name__}) + logging.error(f"Error deleting keyword: {e}") raise DatabaseError(f"Error deleting keyword: {e}") @@ -1000,7 +1036,7 @@ def add_media_version(conn, media_id: int, prompt: str, summary: str) -> None: # Function to search the database with advanced options, including keyword search and full-text search -def sqlite_search_db(search_query: str, search_fields: List[str], keywords: str, page: int = 1, results_per_page: int = 10, connection=None): +def search_media_db(search_query: str, search_fields: List[str], keywords: str, page: int = 1, results_per_page: int = 20, connection=None): if page < 1: raise ValueError("Page number must be 1 or greater.") @@ -1055,7 +1091,7 @@ def sqlite_search_db(search_query: str, search_fields: List[str], keywords: str, # Gradio function to handle user input and display results with pagination, with better feedback def search_and_display(search_query, search_fields, keywords, page): - results = sqlite_search_db(search_query, search_fields, keywords, page) + results = search_media_db(search_query, search_fields, keywords, page) if isinstance(results, pd.DataFrame): # Convert DataFrame to a list of tuples or lists @@ -1133,7 +1169,7 @@ def format_results(results): # Function to export search results to CSV or markdown with pagination def export_to_file(search_query: str, search_fields: List[str], keyword: str, page: int = 1, results_per_file: int = 1000, export_format: str = 'csv'): try: - results = sqlite_search_db(search_query, search_fields, keyword, page, results_per_file) + results = search_media_db(search_query, search_fields, keyword, page, results_per_file) if not results: return "No results found to export." 
@@ -1381,303 +1417,6 @@ def schedule_chunking(media_id: int, content: str, media_name: str): ####################################################################################################################### -####################################################################################################################### -# -# Functions to manage prompts DB - -def create_prompts_db(): - logging.debug("create_prompts_db: Creating prompts database.") - with sqlite3.connect(get_database_path('prompts.db')) as conn: - cursor = conn.cursor() - cursor.executescript(''' - CREATE TABLE IF NOT EXISTS Prompts ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - name TEXT NOT NULL UNIQUE, - author TEXT, - details TEXT, - system TEXT, - user TEXT - ); - CREATE TABLE IF NOT EXISTS Keywords ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - keyword TEXT NOT NULL UNIQUE COLLATE NOCASE - ); - CREATE TABLE IF NOT EXISTS PromptKeywords ( - prompt_id INTEGER, - keyword_id INTEGER, - FOREIGN KEY (prompt_id) REFERENCES Prompts (id), - FOREIGN KEY (keyword_id) REFERENCES Keywords (id), - PRIMARY KEY (prompt_id, keyword_id) - ); - CREATE INDEX IF NOT EXISTS idx_keywords_keyword ON Keywords(keyword); - CREATE INDEX IF NOT EXISTS idx_promptkeywords_prompt_id ON PromptKeywords(prompt_id); - CREATE INDEX IF NOT EXISTS idx_promptkeywords_keyword_id ON PromptKeywords(keyword_id); - ''') - -# FIXME - dirty hack that should be removed later... 
-# Migration function to add the 'author' column to the Prompts table -def add_author_column_to_prompts(): - with sqlite3.connect(get_database_path('prompts.db')) as conn: - cursor = conn.cursor() - # Check if 'author' column already exists - cursor.execute("PRAGMA table_info(Prompts)") - columns = [col[1] for col in cursor.fetchall()] - - if 'author' not in columns: - # Add the 'author' column - cursor.execute('ALTER TABLE Prompts ADD COLUMN author TEXT') - print("Author column added to Prompts table.") - else: - print("Author column already exists in Prompts table.") - -add_author_column_to_prompts() - -def normalize_keyword(keyword): - return re.sub(r'\s+', ' ', keyword.strip().lower()) - - -# FIXME - update calls to this function to use the new args -def add_prompt(name, author, details, system=None, user=None, keywords=None): - logging.debug(f"add_prompt: Adding prompt with name: {name}, author: {author}, system: {system}, user: {user}, keywords: {keywords}") - if not name: - logging.error("add_prompt: A name is required.") - return "A name is required." - - try: - with sqlite3.connect(get_database_path('prompts.db')) as conn: - cursor = conn.cursor() - cursor.execute(''' - INSERT INTO Prompts (name, author, details, system, user) - VALUES (?, ?, ?, ?, ?) - ''', (name, author, details, system, user)) - prompt_id = cursor.lastrowid - - if keywords: - normalized_keywords = [normalize_keyword(k) for k in keywords if k.strip()] - for keyword in set(normalized_keywords): # Use set to remove duplicates - cursor.execute(''' - INSERT OR IGNORE INTO Keywords (keyword) VALUES (?) - ''', (keyword,)) - cursor.execute('SELECT id FROM Keywords WHERE keyword = ?', (keyword,)) - keyword_id = cursor.fetchone()[0] - cursor.execute(''' - INSERT OR IGNORE INTO PromptKeywords (prompt_id, keyword_id) VALUES (?, ?) - ''', (prompt_id, keyword_id)) - return "Prompt added successfully." - except sqlite3.IntegrityError: - return "Prompt with this name already exists." 
- except sqlite3.Error as e: - return f"Database error: {e}" - - -def fetch_prompt_details(name): - logging.debug(f"fetch_prompt_details: Fetching details for prompt: {name}") - with sqlite3.connect(get_database_path('prompts.db')) as conn: - cursor = conn.cursor() - cursor.execute(''' - SELECT p.name, p.author, p.details, p.system, p.user, GROUP_CONCAT(k.keyword, ', ') as keywords - FROM Prompts p - LEFT JOIN PromptKeywords pk ON p.id = pk.prompt_id - LEFT JOIN Keywords k ON pk.keyword_id = k.id - WHERE p.name = ? - GROUP BY p.id - ''', (name,)) - return cursor.fetchone() - - -def list_prompts(page=1, per_page=10): - logging.debug(f"list_prompts: Listing prompts for page {page} with {per_page} prompts per page.") - offset = (page - 1) * per_page - with sqlite3.connect(get_database_path('prompts.db')) as conn: - cursor = conn.cursor() - cursor.execute('SELECT name FROM Prompts LIMIT ? OFFSET ?', (per_page, offset)) - prompts = [row[0] for row in cursor.fetchall()] - - # Get total count of prompts - cursor.execute('SELECT COUNT(*) FROM Prompts') - total_count = cursor.fetchone()[0] - - total_pages = (total_count + per_page - 1) // per_page - return prompts, total_pages, page - -# This will not scale. For a large number of prompts, use a more efficient method. -# FIXME - see above statement. 
-def load_preset_prompts(): - logging.debug("load_preset_prompts: Loading preset prompts.") - try: - with sqlite3.connect(get_database_path('prompts.db')) as conn: - cursor = conn.cursor() - cursor.execute('SELECT name FROM Prompts ORDER BY name ASC') - prompts = [row[0] for row in cursor.fetchall()] - return prompts - except sqlite3.Error as e: - print(f"Database error: {e}") - return [] - - -def insert_prompt_to_db(title, author, description, system_prompt, user_prompt, keywords=None): - return add_prompt(title, author, description, system_prompt, user_prompt, keywords) - - -def get_prompt_db_connection(): - prompt_db_path = get_database_path('prompts.db') - return sqlite3.connect(prompt_db_path) - - -def search_prompts(query): - logging.debug(f"search_prompts: Searching prompts with query: {query}") - try: - with get_prompt_db_connection() as conn: - cursor = conn.cursor() - cursor.execute(""" - SELECT p.name, p.details, p.system, p.user, GROUP_CONCAT(k.keyword, ', ') as keywords - FROM Prompts p - LEFT JOIN PromptKeywords pk ON p.id = pk.prompt_id - LEFT JOIN Keywords k ON pk.keyword_id = k.id - WHERE p.name LIKE ? OR p.details LIKE ? OR p.system LIKE ? OR p.user LIKE ? OR k.keyword LIKE ? - GROUP BY p.id - ORDER BY p.name - """, (f'%{query}%', f'%{query}%', f'%{query}%', f'%{query}%', f'%{query}%')) - return cursor.fetchall() - except sqlite3.Error as e: - logging.error(f"Error searching prompts: {e}") - return [] - - -def search_prompts_by_keyword(keyword, page=1, per_page=10): - logging.debug(f"search_prompts_by_keyword: Searching prompts by keyword: {keyword}") - normalized_keyword = normalize_keyword(keyword) - offset = (page - 1) * per_page - with sqlite3.connect(get_database_path('prompts.db')) as conn: - cursor = conn.cursor() - cursor.execute(''' - SELECT DISTINCT p.name - FROM Prompts p - JOIN PromptKeywords pk ON p.id = pk.prompt_id - JOIN Keywords k ON pk.keyword_id = k.id - WHERE k.keyword LIKE ? - LIMIT ? OFFSET ? 
- ''', ('%' + normalized_keyword + '%', per_page, offset)) - prompts = [row[0] for row in cursor.fetchall()] - - # Get total count of matching prompts - cursor.execute(''' - SELECT COUNT(DISTINCT p.id) - FROM Prompts p - JOIN PromptKeywords pk ON p.id = pk.prompt_id - JOIN Keywords k ON pk.keyword_id = k.id - WHERE k.keyword LIKE ? - ''', ('%' + normalized_keyword + '%',)) - total_count = cursor.fetchone()[0] - - total_pages = (total_count + per_page - 1) // per_page - return prompts, total_pages, page - - -def update_prompt_keywords(prompt_name, new_keywords): - logging.debug(f"update_prompt_keywords: Updating keywords for prompt: {prompt_name}") - try: - with sqlite3.connect(get_database_path('prompts.db')) as conn: - cursor = conn.cursor() - - cursor.execute('SELECT id FROM Prompts WHERE name = ?', (prompt_name,)) - prompt_id = cursor.fetchone() - if not prompt_id: - return "Prompt not found." - prompt_id = prompt_id[0] - - cursor.execute('DELETE FROM PromptKeywords WHERE prompt_id = ?', (prompt_id,)) - - normalized_keywords = [normalize_keyword(k) for k in new_keywords if k.strip()] - for keyword in set(normalized_keywords): # Use set to remove duplicates - cursor.execute('INSERT OR IGNORE INTO Keywords (keyword) VALUES (?)', (keyword,)) - cursor.execute('SELECT id FROM Keywords WHERE keyword = ?', (keyword,)) - keyword_id = cursor.fetchone()[0] - cursor.execute('INSERT INTO PromptKeywords (prompt_id, keyword_id) VALUES (?, ?)', - (prompt_id, keyword_id)) - - # Remove unused keywords - cursor.execute(''' - DELETE FROM Keywords - WHERE id NOT IN (SELECT DISTINCT keyword_id FROM PromptKeywords) - ''') - return "Keywords updated successfully." - except sqlite3.Error as e: - return f"Database error: {e}" - - -def add_or_update_prompt(title, author, description, system_prompt, user_prompt, keywords=None): - logging.debug(f"add_or_update_prompt: Adding or updating prompt: {title}") - if not title: - return "Error: Title is required." 
- - existing_prompt = fetch_prompt_details(title) - if existing_prompt: - # Update existing prompt - result = update_prompt_in_db(title, author, description, system_prompt, user_prompt) - if "successfully" in result: - # Update keywords if the prompt update was successful - keyword_result = update_prompt_keywords(title, keywords or []) - result += f" {keyword_result}" - else: - # Insert new prompt - result = insert_prompt_to_db(title, author, description, system_prompt, user_prompt, keywords) - - return result - - -def load_prompt_details(selected_prompt): - logging.debug(f"load_prompt_details: Loading prompt details for {selected_prompt}") - if selected_prompt: - details = fetch_prompt_details(selected_prompt) - if details: - return details[0], details[1], details[2], details[3], details[4], details[5] - return "", "", "", "", "", "" - - -def update_prompt_in_db(title, author, description, system_prompt, user_prompt): - logging.debug(f"update_prompt_in_db: Updating prompt: {title}") - try: - with sqlite3.connect(get_database_path('prompts.db')) as conn: - cursor = conn.cursor() - cursor.execute( - "UPDATE Prompts SET author = ?, details = ?, system = ?, user = ? WHERE name = ?", - (author, description, system_prompt, user_prompt, title) - ) - if cursor.rowcount == 0: - return "No prompt found with the given title." - return "Prompt updated successfully!" 
- except sqlite3.Error as e: - return f"Error updating prompt: {e}" - - -create_prompts_db() - -def delete_prompt(prompt_id): - logging.debug(f"delete_prompt: Deleting prompt with ID: {prompt_id}") - try: - with sqlite3.connect(get_database_path('prompts.db')) as conn: - cursor = conn.cursor() - - # Delete associated keywords - cursor.execute("DELETE FROM PromptKeywords WHERE prompt_id = ?", (prompt_id,)) - - # Delete the prompt - cursor.execute("DELETE FROM Prompts WHERE id = ?", (prompt_id,)) - - if cursor.rowcount == 0: - return f"No prompt found with ID {prompt_id}" - else: - conn.commit() - return f"Prompt with ID {prompt_id} has been successfully deleted" - except sqlite3.Error as e: - return f"An error occurred: {e}" - -# -# -####################################################################################################################### - - ####################################################################################################################### # # Function to fetch/update media content @@ -2020,204 +1759,6 @@ def import_obsidian_note_to_db(note_data): ####################################################################################################################### -####################################################################################################################### -# -# Chat-related Functions - - - -def create_chat_conversation(media_id, conversation_name): - try: - with db.get_connection() as conn: - cursor = conn.cursor() - cursor.execute(''' - INSERT INTO ChatConversations (media_id, conversation_name, created_at, updated_at) - VALUES (?, ?, CURRENT_TIMESTAMP, CURRENT_TIMESTAMP) - ''', (media_id, conversation_name)) - conn.commit() - return cursor.lastrowid - except sqlite3.Error as e: - logging.error(f"Error creating chat conversation: {e}") - raise DatabaseError(f"Error creating chat conversation: {e}") - - -def add_chat_message(conversation_id: int, sender: str, message: str) -> int: - try: - with 
db.get_connection() as conn: - cursor = conn.cursor() - cursor.execute(''' - INSERT INTO ChatMessages (conversation_id, sender, message) - VALUES (?, ?, ?) - ''', (conversation_id, sender, message)) - conn.commit() - return cursor.lastrowid - except sqlite3.Error as e: - logging.error(f"Error adding chat message: {e}") - raise DatabaseError(f"Error adding chat message: {e}") - - -def get_chat_messages(conversation_id: int) -> List[Dict[str, Any]]: - try: - with db.get_connection() as conn: - cursor = conn.cursor() - cursor.execute(''' - SELECT id, sender, message, timestamp - FROM ChatMessages - WHERE conversation_id = ? - ORDER BY timestamp ASC - ''', (conversation_id,)) - messages = cursor.fetchall() - return [ - { - 'id': msg[0], - 'sender': msg[1], - 'message': msg[2], - 'timestamp': msg[3] - } - for msg in messages - ] - except sqlite3.Error as e: - logging.error(f"Error retrieving chat messages: {e}") - raise DatabaseError(f"Error retrieving chat messages: {e}") - - -def search_chat_conversations(search_query: str) -> List[Dict[str, Any]]: - try: - with db.get_connection() as conn: - cursor = conn.cursor() - cursor.execute(''' - SELECT cc.id, cc.media_id, cc.conversation_name, cc.created_at, m.title as media_title - FROM ChatConversations cc - LEFT JOIN Media m ON cc.media_id = m.id - WHERE cc.conversation_name LIKE ? OR m.title LIKE ? 
- ORDER BY cc.updated_at DESC - ''', (f'%{search_query}%', f'%{search_query}%')) - conversations = cursor.fetchall() - return [ - { - 'id': conv[0], - 'media_id': conv[1], - 'conversation_name': conv[2], - 'created_at': conv[3], - 'media_title': conv[4] or "Unknown Media" - } - for conv in conversations - ] - except sqlite3.Error as e: - logging.error(f"Error searching chat conversations: {e}") - return [] - - -def update_chat_message(message_id: int, new_message: str) -> None: - try: - with db.get_connection() as conn: - cursor = conn.cursor() - cursor.execute(''' - UPDATE ChatMessages - SET message = ?, timestamp = CURRENT_TIMESTAMP - WHERE id = ? - ''', (new_message, message_id)) - conn.commit() - except sqlite3.Error as e: - logging.error(f"Error updating chat message: {e}") - raise DatabaseError(f"Error updating chat message: {e}") - - -def delete_chat_message(message_id: int) -> None: - try: - with db.get_connection() as conn: - cursor = conn.cursor() - cursor.execute('DELETE FROM ChatMessages WHERE id = ?', (message_id,)) - conn.commit() - except sqlite3.Error as e: - logging.error(f"Error deleting chat message: {e}") - raise DatabaseError(f"Error deleting chat message: {e}") - - -def save_chat_history_to_database(chatbot, conversation_id, media_id, media_name, conversation_name): - try: - with db.get_connection() as conn: - cursor = conn.cursor() - - # If conversation_id is None, create a new conversation - if conversation_id is None: - cursor.execute(''' - INSERT INTO ChatConversations (media_id, media_name, conversation_name, created_at, updated_at) - VALUES (?, ?, ?, CURRENT_TIMESTAMP, CURRENT_TIMESTAMP) - ''', (media_id, media_name, conversation_name)) - conversation_id = cursor.lastrowid - else: - # If conversation exists, update the media_name - cursor.execute(''' - UPDATE ChatConversations - SET media_name = ?, updated_at = CURRENT_TIMESTAMP - WHERE id = ? 
- ''', (media_name, conversation_id)) - - # Save each message in the chatbot history - for i, (user_msg, ai_msg) in enumerate(chatbot): - cursor.execute(''' - INSERT INTO ChatMessages (conversation_id, sender, message, timestamp) - VALUES (?, ?, ?, CURRENT_TIMESTAMP) - ''', (conversation_id, 'user', user_msg)) - - cursor.execute(''' - INSERT INTO ChatMessages (conversation_id, sender, message, timestamp) - VALUES (?, ?, ?, CURRENT_TIMESTAMP) - ''', (conversation_id, 'ai', ai_msg)) - - # Update the conversation's updated_at timestamp - cursor.execute(''' - UPDATE ChatConversations - SET updated_at = CURRENT_TIMESTAMP - WHERE id = ? - ''', (conversation_id,)) - - conn.commit() - - return conversation_id - except Exception as e: - logging.error(f"Error saving chat history to database: {str(e)}") - raise - - -def get_conversation_name(conversation_id): - if conversation_id is None: - return None - - try: - with sqlite3.connect('media_summary.db') as conn: # Replace with your actual database name - cursor = conn.cursor() - - query = """ - SELECT conversation_name, media_name - FROM ChatConversations - WHERE id = ? 
- """ - - cursor.execute(query, (conversation_id,)) - result = cursor.fetchone() - - if result: - conversation_name, media_name = result - if conversation_name: - return conversation_name - elif media_name: - return f"{media_name}-chat" - - return None # Return None if no result found - except sqlite3.Error as e: - logging.error(f"Database error in get_conversation_name: {e}") - return None - except Exception as e: - logging.error(f"Unexpected error in get_conversation_name: {e}") - return None - -# -# End of Chat-related Functions -####################################################################################################################### - - ####################################################################################################################### # # Functions to Compare Transcripts @@ -2837,29 +2378,42 @@ def process_chunks(database, chunks: List[Dict], media_id: int, batch_size: int :param media_id: ID of the media these chunks belong to :param batch_size: Number of chunks to process in each batch """ + log_counter("process_chunks_attempt", labels={"media_id": media_id}) + start_time = time.time() total_chunks = len(chunks) processed_chunks = 0 - for i in range(0, total_chunks, batch_size): - batch = chunks[i:i + batch_size] - chunk_data = [ - (media_id, chunk['text'], chunk['start_index'], chunk['end_index']) - for chunk in batch - ] - - try: - database.execute_many( - "INSERT INTO MediaChunks (media_id, chunk_text, start_index, end_index) VALUES (?, ?, ?, ?)", - chunk_data - ) - processed_chunks += len(batch) - logging.info(f"Processed {processed_chunks}/{total_chunks} chunks for media_id {media_id}") - except Exception as e: - logging.error(f"Error inserting chunk batch for media_id {media_id}: {e}") - # Optionally, you could raise an exception here to stop processing - # raise + try: + for i in range(0, total_chunks, batch_size): + batch = chunks[i:i + batch_size] + chunk_data = [ + (media_id, chunk['text'], chunk['start_index'], 
chunk['end_index']) + for chunk in batch + ] - logging.info(f"Finished processing all {total_chunks} chunks for media_id {media_id}") + try: + database.execute_many( + "INSERT INTO MediaChunks (media_id, chunk_text, start_index, end_index) VALUES (?, ?, ?, ?)", + chunk_data + ) + processed_chunks += len(batch) + logging.info(f"Processed {processed_chunks}/{total_chunks} chunks for media_id {media_id}") + log_counter("process_chunks_batch_success", labels={"media_id": media_id}) + except Exception as e: + logging.error(f"Error inserting chunk batch for media_id {media_id}: {e}") + log_counter("process_chunks_batch_error", labels={"media_id": media_id, "error_type": type(e).__name__}) + # Optionally, you could raise an exception here to stop processing + # raise + + logging.info(f"Finished processing all {total_chunks} chunks for media_id {media_id}") + duration = time.time() - start_time + log_histogram("process_chunks_duration", duration, labels={"media_id": media_id}) + log_counter("process_chunks_success", labels={"media_id": media_id}) + except Exception as e: + duration = time.time() - start_time + log_histogram("process_chunks_duration", duration, labels={"media_id": media_id}) + log_counter("process_chunks_error", labels={"media_id": media_id, "error_type": type(e).__name__}) + logging.error(f"Error processing chunks for media_id {media_id}: {e}") # Usage example: @@ -2995,46 +2549,48 @@ def update_media_table(db): # # Workflow Functions +# Workflow Functions def save_workflow_chat_to_db(chat_history, workflow_name, conversation_id=None): - try: - with db.get_connection() as conn: - cursor = conn.cursor() - - if conversation_id is None: - # Create a new conversation - conversation_name = f"{workflow_name}_Workflow_{datetime.now().strftime('%Y%m%d_%H%M%S')}" - cursor.execute(''' - INSERT INTO ChatConversations (media_id, media_name, conversation_name, created_at, updated_at) - VALUES (NULL, ?, ?, CURRENT_TIMESTAMP, CURRENT_TIMESTAMP) - ''', (workflow_name, 
conversation_name)) - conversation_id = cursor.lastrowid - else: - # Update existing conversation - cursor.execute(''' - UPDATE ChatConversations - SET updated_at = CURRENT_TIMESTAMP - WHERE id = ? - ''', (conversation_id,)) - - # Save messages - for user_msg, ai_msg in chat_history: - if user_msg: - cursor.execute(''' - INSERT INTO ChatMessages (conversation_id, sender, message, timestamp) - VALUES (?, 'user', ?, CURRENT_TIMESTAMP) - ''', (conversation_id, user_msg)) - if ai_msg: - cursor.execute(''' - INSERT INTO ChatMessages (conversation_id, sender, message, timestamp) - VALUES (?, 'ai', ?, CURRENT_TIMESTAMP) - ''', (conversation_id, ai_msg)) - - conn.commit() - - return conversation_id, f"Chat saved successfully! Conversation ID: {conversation_id}" - except Exception as e: - logging.error(f"Error saving workflow chat to database: {str(e)}") - return None, f"Error saving chat to database: {str(e)}" + pass +# try: +# with db.get_connection() as conn: +# cursor = conn.cursor() +# +# if conversation_id is None: +# # Create a new conversation +# conversation_name = f"{workflow_name}_Workflow_{datetime.now().strftime('%Y%m%d_%H%M%S')}" +# cursor.execute(''' +# INSERT INTO ChatConversations (media_id, media_name, conversation_name, created_at, updated_at) +# VALUES (NULL, ?, ?, CURRENT_TIMESTAMP, CURRENT_TIMESTAMP) +# ''', (workflow_name, conversation_name)) +# conversation_id = cursor.lastrowid +# else: +# # Update existing conversation +# cursor.execute(''' +# UPDATE ChatConversations +# SET updated_at = CURRENT_TIMESTAMP +# WHERE id = ? 
+# ''', (conversation_id,)) +# +# # Save messages +# for user_msg, ai_msg in chat_history: +# if user_msg: +# cursor.execute(''' +# INSERT INTO ChatMessages (conversation_id, sender, message, timestamp) +# VALUES (?, 'user', ?, CURRENT_TIMESTAMP) +# ''', (conversation_id, user_msg)) +# if ai_msg: +# cursor.execute(''' +# INSERT INTO ChatMessages (conversation_id, sender, message, timestamp) +# VALUES (?, 'ai', ?, CURRENT_TIMESTAMP) +# ''', (conversation_id, ai_msg)) +# +# conn.commit() +# +# return conversation_id, f"Chat saved successfully! Conversation ID: {conversation_id}" +# except Exception as e: +# logging.error(f"Error saving workflow chat to database: {str(e)}") +# return None, f"Error saving chat to database: {str(e)}" def get_workflow_chat(conversation_id): diff --git a/App_Function_Libraries/Gradio_Related.py b/App_Function_Libraries/Gradio_Related.py index 37812ab8e76ae2af6ff83fa0b623b864f7f50631..afa84de9510c3c7a8a3c4ab2220a18f6e392b8e8 100644 --- a/App_Function_Libraries/Gradio_Related.py +++ b/App_Function_Libraries/Gradio_Related.py @@ -1,420 +1,600 @@ -# Gradio_Related.py -######################################### -# Gradio UI Functions Library -# I fucking hate Gradio. 
-# -######################################### -# -# Built-In Imports -import logging -import os -import webbrowser - -# -# Import 3rd-Party Libraries -import gradio as gr -# -# Local Imports -from App_Function_Libraries.DB.DB_Manager import get_db_config -from App_Function_Libraries.Gradio_UI.Arxiv_tab import create_arxiv_tab -from App_Function_Libraries.Gradio_UI.Audio_ingestion_tab import create_audio_processing_tab -from App_Function_Libraries.Gradio_UI.Book_Ingestion_tab import create_import_book_tab -from App_Function_Libraries.Gradio_UI.Character_Chat_tab import create_character_card_interaction_tab, create_character_chat_mgmt_tab, create_custom_character_card_tab, \ - create_character_card_validation_tab, create_export_characters_tab -from App_Function_Libraries.Gradio_UI.Character_interaction_tab import create_narrator_controlled_conversation_tab, \ - create_multiple_character_chat_tab -from App_Function_Libraries.Gradio_UI.Chat_ui import create_chat_management_tab, \ - create_chat_interface_four, create_chat_interface_multi_api, create_chat_interface_stacked, create_chat_interface -from App_Function_Libraries.Gradio_UI.Config_tab import create_config_editor_tab -from App_Function_Libraries.Gradio_UI.Explain_summarize_tab import create_summarize_explain_tab -from App_Function_Libraries.Gradio_UI.Export_Functionality import create_export_tab -from App_Function_Libraries.Gradio_UI.Backup_Functionality import create_backup_tab, create_view_backups_tab, \ - create_restore_backup_tab -from App_Function_Libraries.Gradio_UI.Import_Functionality import create_import_single_prompt_tab, \ - create_import_obsidian_vault_tab, create_import_item_tab, create_import_multiple_prompts_tab -from App_Function_Libraries.Gradio_UI.Introduction_tab import create_introduction_tab -from App_Function_Libraries.Gradio_UI.Keywords import create_view_keywords_tab, create_add_keyword_tab, \ - create_delete_keyword_tab, create_export_keywords_tab -from 
App_Function_Libraries.Gradio_UI.Live_Recording import create_live_recording_tab -#from App_Function_Libraries.Gradio_UI.Llamafile_tab import create_chat_with_llamafile_tab -#from App_Function_Libraries.Gradio_UI.MMLU_Pro_tab import create_mmlu_pro_tab -from App_Function_Libraries.Gradio_UI.Media_edit import create_prompt_clone_tab, create_prompt_edit_tab, \ - create_media_edit_and_clone_tab, create_media_edit_tab -from App_Function_Libraries.Gradio_UI.Media_wiki_tab import create_mediawiki_import_tab, create_mediawiki_config_tab -from App_Function_Libraries.Gradio_UI.PDF_ingestion_tab import create_pdf_ingestion_tab, create_pdf_ingestion_test_tab -from App_Function_Libraries.Gradio_UI.Plaintext_tab_import import create_plain_text_import_tab -from App_Function_Libraries.Gradio_UI.Podcast_tab import create_podcast_tab -from App_Function_Libraries.Gradio_UI.Prompt_Suggestion_tab import create_prompt_suggestion_tab -from App_Function_Libraries.Gradio_UI.RAG_QA_Chat_tab import create_rag_qa_chat_tab, create_rag_qa_notes_management_tab, \ - create_rag_qa_chat_management_tab -from App_Function_Libraries.Gradio_UI.Re_summarize_tab import create_resummary_tab -from App_Function_Libraries.Gradio_UI.Search_Tab import create_prompt_search_tab, \ - create_search_summaries_tab, create_search_tab -from App_Function_Libraries.Gradio_UI.RAG_Chat_tab import create_rag_tab -from App_Function_Libraries.Gradio_UI.Embeddings_tab import create_embeddings_tab, create_view_embeddings_tab, \ - create_purge_embeddings_tab -from App_Function_Libraries.Gradio_UI.Trash import create_view_trash_tab, create_empty_trash_tab, \ - create_delete_trash_tab, create_search_and_mark_trash_tab -from App_Function_Libraries.Gradio_UI.Utilities import create_utilities_yt_timestamp_tab, create_utilities_yt_audio_tab, \ - create_utilities_yt_video_tab -from App_Function_Libraries.Gradio_UI.Video_transcription_tab import create_video_transcription_tab -from App_Function_Libraries.Gradio_UI.View_tab import 
create_manage_items_tab -from App_Function_Libraries.Gradio_UI.Website_scraping_tab import create_website_scraping_tab -from App_Function_Libraries.Gradio_UI.Chat_Workflows import chat_workflows_tab -from App_Function_Libraries.Gradio_UI.View_DB_Items_tab import create_prompt_view_tab, \ - create_view_all_with_versions_tab, create_viewing_tab -# -# Gradio UI Imports -from App_Function_Libraries.Gradio_UI.Evaluations_Benchmarks_tab import create_geval_tab, create_infinite_bench_tab -#from App_Function_Libraries.Local_LLM.Local_LLM_huggingface import create_huggingface_tab -#from App_Function_Libraries.Local_LLM.Local_LLM_ollama import create_ollama_tab -# -####################################################################################################################### -# Function Definitions -# - - -# Disable Gradio Analytics -os.environ['GRADIO_ANALYTICS_ENABLED'] = 'False' - - -custom_prompt_input = None -server_mode = False -share_public = False -custom_prompt_summarize_bulleted_notes = (""" - You are a bulleted notes specialist. [INST]```When creating comprehensive bulleted notes, you should follow these guidelines: Use multiple headings based on the referenced topics, not categories like quotes or terms. Headings should be surrounded by bold formatting and not be listed as bullet points themselves. Leave no space between headings and their corresponding list items underneath. Important terms within the content should be emphasized by setting them in bold font. Any text that ends with a colon should also be bolded. Before submitting your response, review the instructions, and make any corrections necessary to adhered to the specified format. 
Do not reference these instructions within the notes.``` \nBased on the content between backticks create comprehensive bulleted notes.[/INST] - **Bulleted Note Creation Guidelines** - - **Headings**: - - Based on referenced topics, not categories like quotes or terms - - Surrounded by **bold** formatting - - Not listed as bullet points - - No space between headings and list items underneath - - **Emphasis**: - - **Important terms** set in bold font - - **Text ending in a colon**: also bolded - - **Review**: - - Ensure adherence to specified format - - Do not reference these instructions in your response.[INST] {{ .Prompt }} [/INST] - """) -# -# End of globals -####################################################################################################################### -# -# Start of Video/Audio Transcription and Summarization Functions -# -# Functions: -# FIXME -# -# -################################################################################################################ -# Functions for Re-Summarization -# -# Functions: -# FIXME -# End of Re-Summarization Functions -# -############################################################################################################################################################################################################################ -# -# Explain/Summarize This Tab -# -# Functions: -# FIXME -# -# -############################################################################################################################################################################################################################ -# -# Transcript Comparison Tab -# -# Functions: -# FIXME -# -# -########################################################################################################################################################################################################################### -# -# Search Tab -# -# Functions: -# FIXME -# -# End of Search Tab Functions -# 
-############################################################################################################################################################################################################################## -# -# Llamafile Tab -# -# Functions: -# FIXME -# -# End of Llamafile Tab Functions -############################################################################################################################################################################################################################## -# -# Chat Interface Tab Functions -# -# Functions: -# FIXME -# -# -# End of Chat Interface Tab Functions -################################################################################################################################################################################################################################ -# -# Media Edit Tab Functions -# Functions: -# Fixme -# create_media_edit_tab(): -##### Trash Tab -# FIXME -# Functions: -# -# End of Media Edit Tab Functions -################################################################################################################ -# -# Import Items Tab Functions -# -# Functions: -#FIXME -# End of Import Items Tab Functions -################################################################################################################ -# -# Export Items Tab Functions -# -# Functions: -# FIXME -# -# -# End of Export Items Tab Functions -################################################################################################################ -# -# Keyword Management Tab Functions -# -# Functions: -# create_view_keywords_tab(): -# FIXME -# -# End of Keyword Management Tab Functions -################################################################################################################ -# -# Document Editing Tab Functions -# -# Functions: -# #FIXME -# -# 
-################################################################################################################ -# -# Utilities Tab Functions -# Functions: -# create_utilities_yt_video_tab(): -# #FIXME - -# -# End of Utilities Tab Functions -################################################################################################################ - -# FIXME - Prompt sample box -# -# # Sample data -# prompts_category_1 = [ -# "What are the key points discussed in the video?", -# "Summarize the main arguments made by the speaker.", -# "Describe the conclusions of the study presented." -# ] -# -# prompts_category_2 = [ -# "How does the proposed solution address the problem?", -# "What are the implications of the findings?", -# "Can you explain the theory behind the observed phenomenon?" -# ] -# -# all_prompts2 = prompts_category_1 + prompts_category_2 - - -def launch_ui(share_public=None, server_mode=False): - webbrowser.open_new_tab('http://127.0.0.1:7860/?__theme=dark') - share=share_public - css = """ - .result-box { - margin-bottom: 20px; - border: 1px solid #ddd; - padding: 10px; - } - .result-box.error { - border-color: #ff0000; - background-color: #ffeeee; - } - .transcription, .summary { - max-height: 800px; - overflow-y: auto; - border: 1px solid #eee; - padding: 10px; - margin-top: 10px; - } - """ - - with gr.Blocks(theme='bethecloud/storj_theme',css=css) as iface: - gr.HTML( - """ - - """ - ) - db_config = get_db_config() - db_type = db_config['type'] - gr.Markdown(f"# tl/dw: Your LLM-powered Research Multi-tool") - gr.Markdown(f"(Using {db_type.capitalize()} Database)") - with gr.Tabs(): - with gr.TabItem("Transcription / Summarization / Ingestion", id="ingestion-grouping", visible=True): - with gr.Tabs(): - create_video_transcription_tab() - #create_audio_processing_tab() - #create_podcast_tab() - #create_import_book_tab() - ##create_plain_text_import_tab() - #create_website_scraping_tab() - #create_pdf_ingestion_tab() - 
#create_pdf_ingestion_test_tab() - #create_resummary_tab() - #create_summarize_explain_tab() - #create_live_recording_tab() - #create_arxiv_tab() - - #with gr.TabItem("Text Search", id="text search", visible=True): - #create_search_tab() - #create_search_summaries_tab() - - #with gr.TabItem("RAG Chat/Search", id="RAG Chat Notes group", visible=True): - #create_rag_tab() - #create_rag_qa_chat_tab() - #create_rag_qa_notes_management_tab() - #create_rag_qa_chat_management_tab() - - #with gr.TabItem("Chat with an LLM", id="LLM Chat group", visible=True): - #create_chat_interface() - #create_chat_interface_stacked() - #create_chat_interface_multi_api() - #create_chat_interface_four() - #create_chat_with_llamafile_tab() - #create_chat_management_tab() - #chat_workflows_tab() - - - #with gr.TabItem("Character Chat", id="character chat group", visible=True): - #create_character_card_interaction_tab() - #create_character_chat_mgmt_tab() - #create_custom_character_card_tab() - #create_character_card_validation_tab() - #create_multiple_character_chat_tab() - #create_narrator_controlled_conversation_tab() - #create_export_characters_tab() - - - #with gr.TabItem("View DB Items", id="view db items group", visible=True): - # This one works - #create_view_all_with_versions_tab() - # This one is WIP - #create_viewing_tab() - #create_prompt_view_tab() - - - #with gr.TabItem("Prompts", id='view prompts group', visible=True): - #create_prompt_view_tab() - #create_prompt_search_tab() - #create_prompt_edit_tab() - #create_prompt_clone_tab() - #create_prompt_suggestion_tab() - - - #with gr.TabItem("Manage / Edit Existing Items", id="manage group", visible=True): - #create_media_edit_tab() - #create_manage_items_tab() - #create_media_edit_and_clone_tab() - # FIXME - #create_compare_transcripts_tab() - - - #with gr.TabItem("Embeddings Management", id="embeddings group", visible=True): - #create_embeddings_tab() - #create_view_embeddings_tab() - #create_purge_embeddings_tab() - - #with 
gr.TabItem("Writing Tools", id="writing_tools group", visible=True): - #from App_Function_Libraries.Gradio_UI.Writing_tab import create_document_feedback_tab - #create_document_feedback_tab() - #from App_Function_Libraries.Gradio_UI.Writing_tab import create_grammar_style_check_tab - #create_grammar_style_check_tab() - #from App_Function_Libraries.Gradio_UI.Writing_tab import create_tone_adjustment_tab - #create_tone_adjustment_tab() - #from App_Function_Libraries.Gradio_UI.Writing_tab import create_creative_writing_tab - #create_creative_writing_tab() - #from App_Function_Libraries.Gradio_UI.Writing_tab import create_mikupad_tab - #create_mikupad_tab() - - - #with gr.TabItem("Keywords", id="keywords group", visible=True): - #create_view_keywords_tab() - #create_add_keyword_tab() - #create_delete_keyword_tab() - #create_export_keywords_tab() - - #with gr.TabItem("Import", id="import group", visible=True): - #create_import_item_tab() - #create_import_obsidian_vault_tab() - #create_import_single_prompt_tab() - #create_import_multiple_prompts_tab() - #create_mediawiki_import_tab() - #create_mediawiki_config_tab() - - #with gr.TabItem("Export", id="export group", visible=True): - #create_export_tab() - - #with gr.TabItem("Backup Management", id="backup group", visible=True): - #create_backup_tab() - #create_view_backups_tab() - #create_restore_backup_tab() - - #with gr.TabItem("Utilities", id="util group", visible=True): - #create_utilities_yt_video_tab() - #create_utilities_yt_audio_tab() - #create_utilities_yt_timestamp_tab() - - #with gr.TabItem("Local LLM", id="local llm group", visible=True): - #create_chat_with_llamafile_tab() - #create_ollama_tab() - #create_huggingface_tab() - - #with gr.TabItem("Trashcan", id="trashcan group", visible=True): - #create_search_and_mark_trash_tab() - #create_view_trash_tab() - #create_delete_trash_tab() - #create_empty_trash_tab() - - #with gr.TabItem("Evaluations", id="eval", visible=True): - #create_geval_tab() - 
#create_infinite_bench_tab() - # FIXME - #create_mmlu_pro_tab() - - #with gr.TabItem("Introduction/Help", id="introduction group", visible=True): - #create_introduction_tab() - - #with gr.TabItem("Config Editor", id="config group"): - #create_config_editor_tab() - - # Launch the interface - server_port_variable = 7860 - os.environ['GRADIO_ANALYTICS_ENABLED'] = 'False' - if share==True: - iface.launch(share=True) - elif server_mode and not share_public: - iface.launch(share=False, server_name="0.0.0.0", server_port=server_port_variable, ) - else: - try: - iface.launch(share=False, server_name="0.0.0.0", server_port=server_port_variable, ) - except Exception as e: - logging.error(f"Error launching interface: {str(e)}") +# Gradio_Related.py +######################################### +# Gradio UI Functions Library +# I fucking hate Gradio. +# +######################################### +# +# Built-In Imports +import logging +import os +import webbrowser +# +# Import 3rd-Party Libraries +import gradio as gr +# +# Local Imports +from App_Function_Libraries.DB.DB_Manager import get_db_config, backup_dir +from App_Function_Libraries.DB.RAG_QA_Chat_DB import create_tables +from App_Function_Libraries.Gradio_UI.Anki_tab import create_anki_validation_tab, create_anki_generator_tab +from App_Function_Libraries.Gradio_UI.Arxiv_tab import create_arxiv_tab +from App_Function_Libraries.Gradio_UI.Audio_ingestion_tab import create_audio_processing_tab +from App_Function_Libraries.Gradio_UI.Backup_RAG_Notes_Character_Chat_tab import create_database_management_interface +from App_Function_Libraries.Gradio_UI.Book_Ingestion_tab import create_import_book_tab +from App_Function_Libraries.Gradio_UI.Character_Chat_tab import create_character_card_interaction_tab, create_character_chat_mgmt_tab, create_custom_character_card_tab, \ + create_character_card_validation_tab, create_export_characters_tab +from App_Function_Libraries.Gradio_UI.Character_interaction_tab import 
create_narrator_controlled_conversation_tab, \ + create_multiple_character_chat_tab +from App_Function_Libraries.Gradio_UI.Chat_ui import create_chat_interface_four, create_chat_interface_multi_api, \ + create_chat_interface_stacked, create_chat_interface +from App_Function_Libraries.Gradio_UI.Config_tab import create_config_editor_tab +from App_Function_Libraries.Gradio_UI.Explain_summarize_tab import create_summarize_explain_tab +from App_Function_Libraries.Gradio_UI.Export_Functionality import create_rag_export_tab, create_export_tabs +#from App_Function_Libraries.Gradio_UI.Backup_Functionality import create_backup_tab, create_view_backups_tab, \ +# create_restore_backup_tab +from App_Function_Libraries.Gradio_UI.Import_Functionality import create_import_single_prompt_tab, \ + create_import_obsidian_vault_tab, create_import_item_tab, create_import_multiple_prompts_tab, \ + create_conversation_import_tab +from App_Function_Libraries.Gradio_UI.Introduction_tab import create_introduction_tab +from App_Function_Libraries.Gradio_UI.Keywords import create_view_keywords_tab, create_add_keyword_tab, \ + create_delete_keyword_tab, create_export_keywords_tab, create_rag_qa_keywords_tab, create_character_keywords_tab, \ + create_meta_keywords_tab, create_prompt_keywords_tab +from App_Function_Libraries.Gradio_UI.Live_Recording import create_live_recording_tab +from App_Function_Libraries.Gradio_UI.Llamafile_tab import create_chat_with_llamafile_tab +#from App_Function_Libraries.Gradio_UI.MMLU_Pro_tab import create_mmlu_pro_tab +from App_Function_Libraries.Gradio_UI.Media_edit import create_prompt_clone_tab, create_prompt_edit_tab, \ + create_media_edit_and_clone_tab, create_media_edit_tab +from App_Function_Libraries.Gradio_UI.Media_wiki_tab import create_mediawiki_import_tab, create_mediawiki_config_tab +from App_Function_Libraries.Gradio_UI.Mind_Map_tab import create_mindmap_tab +from App_Function_Libraries.Gradio_UI.PDF_ingestion_tab import create_pdf_ingestion_tab, 
create_pdf_ingestion_test_tab +from App_Function_Libraries.Gradio_UI.Plaintext_tab_import import create_plain_text_import_tab +from App_Function_Libraries.Gradio_UI.Podcast_tab import create_podcast_tab +from App_Function_Libraries.Gradio_UI.Prompt_Suggestion_tab import create_prompt_suggestion_tab +from App_Function_Libraries.Gradio_UI.RAG_QA_Chat_tab import create_rag_qa_chat_tab, create_rag_qa_notes_management_tab, \ + create_rag_qa_chat_management_tab +from App_Function_Libraries.Gradio_UI.Re_summarize_tab import create_resummary_tab +from App_Function_Libraries.Gradio_UI.Search_Tab import create_prompt_search_tab, \ + create_search_summaries_tab, create_search_tab +from App_Function_Libraries.Gradio_UI.RAG_Chat_tab import create_rag_tab +from App_Function_Libraries.Gradio_UI.Embeddings_tab import create_embeddings_tab, create_view_embeddings_tab, \ + create_purge_embeddings_tab +from App_Function_Libraries.Gradio_UI.Semantic_Scholar_tab import create_semantic_scholar_tab +from App_Function_Libraries.Gradio_UI.Trash import create_view_trash_tab, create_empty_trash_tab, \ + create_delete_trash_tab, create_search_and_mark_trash_tab +from App_Function_Libraries.Gradio_UI.Utilities import create_utilities_yt_timestamp_tab, create_utilities_yt_audio_tab, \ + create_utilities_yt_video_tab +from App_Function_Libraries.Gradio_UI.Video_transcription_tab import create_video_transcription_tab +from App_Function_Libraries.Gradio_UI.View_tab import create_manage_items_tab +from App_Function_Libraries.Gradio_UI.Website_scraping_tab import create_website_scraping_tab +from App_Function_Libraries.Gradio_UI.Workflows_tab import chat_workflows_tab +from App_Function_Libraries.Gradio_UI.View_DB_Items_tab import create_view_all_mediadb_with_versions_tab, \ + create_viewing_mediadb_tab, create_view_all_rag_notes_tab, create_viewing_ragdb_tab, \ + create_mediadb_keyword_search_tab, create_ragdb_keyword_items_tab +from App_Function_Libraries.Gradio_UI.Prompts_tab import 
create_prompt_view_tab, create_prompts_export_tab +# +# Gradio UI Imports +from App_Function_Libraries.Gradio_UI.Evaluations_Benchmarks_tab import create_geval_tab, create_infinite_bench_tab +from App_Function_Libraries.Gradio_UI.XML_Ingestion_Tab import create_xml_import_tab +#from App_Function_Libraries.Local_LLM.Local_LLM_huggingface import create_huggingface_tab +from App_Function_Libraries.Local_LLM.Local_LLM_ollama import create_ollama_tab +from App_Function_Libraries.Utils.Utils import load_and_log_configs + +# +####################################################################################################################### +# Function Definitions +# + + +# Disable Gradio Analytics +os.environ['GRADIO_ANALYTICS_ENABLED'] = 'False' + + +custom_prompt_input = None +server_mode = False +share_public = False +custom_prompt_summarize_bulleted_notes = (""" + You are a bulleted notes specialist. [INST]```When creating comprehensive bulleted notes, you should follow these guidelines: Use multiple headings based on the referenced topics, not categories like quotes or terms. Headings should be surrounded by bold formatting and not be listed as bullet points themselves. Leave no space between headings and their corresponding list items underneath. Important terms within the content should be emphasized by setting them in bold font. Any text that ends with a colon should also be bolded. Before submitting your response, review the instructions, and make any corrections necessary to adhered to the specified format. 
Do not reference these instructions within the notes.``` \nBased on the content between backticks create comprehensive bulleted notes.[/INST] + **Bulleted Note Creation Guidelines** + + **Headings**: + - Based on referenced topics, not categories like quotes or terms + - Surrounded by **bold** formatting + - Not listed as bullet points + - No space between headings and list items underneath + + **Emphasis**: + - **Important terms** set in bold font + - **Text ending in a colon**: also bolded + + **Review**: + - Ensure adherence to specified format + - Do not reference these instructions in your response.[INST] {{ .Prompt }} [/INST] + """) +# +# End of globals +####################################################################################################################### +# +# Start of Video/Audio Transcription and Summarization Functions +# +# Functions: +# FIXME +# +# +################################################################################################################ +# Functions for Re-Summarization +# +# Functions: +# FIXME +# End of Re-Summarization Functions +# +############################################################################################################################################################################################################################ +# +# Explain/Summarize This Tab +# +# Functions: +# FIXME +# +# +############################################################################################################################################################################################################################ +# +# Transcript Comparison Tab +# +# Functions: +# FIXME +# +# +########################################################################################################################################################################################################################### +# +# Search Tab +# +# Functions: +# FIXME +# +# End of Search Tab Functions +# 
+############################################################################################################################################################################################################################## +# +# Llamafile Tab +# +# Functions: +# FIXME +# +# End of Llamafile Tab Functions +############################################################################################################################################################################################################################## +# +# Chat Interface Tab Functions +# +# Functions: +# FIXME +# +# +# End of Chat Interface Tab Functions +################################################################################################################################################################################################################################ +# +# Media Edit Tab Functions +# Functions: +# Fixme +# create_media_edit_tab(): +##### Trash Tab +# FIXME +# Functions: +# +# End of Media Edit Tab Functions +################################################################################################################ +# +# Import Items Tab Functions +# +# Functions: +#FIXME +# End of Import Items Tab Functions +################################################################################################################ +# +# Export Items Tab Functions +# +# Functions: +# FIXME +# +# +# End of Export Items Tab Functions +################################################################################################################ +# +# Keyword Management Tab Functions +# +# Functions: +# create_view_keywords_tab(): +# FIXME +# +# End of Keyword Management Tab Functions +################################################################################################################ +# +# Document Editing Tab Functions +# +# Functions: +# #FIXME +# +# 
+################################################################################################################ +# +# Utilities Tab Functions +# Functions: +# create_utilities_yt_video_tab(): +# #FIXME + +# +# End of Utilities Tab Functions +################################################################################################################ + +# FIXME - Prompt sample box +# +# # Sample data +# prompts_category_1 = [ +# "What are the key points discussed in the video?", +# "Summarize the main arguments made by the speaker.", +# "Describe the conclusions of the study presented." +# ] +# +# prompts_category_2 = [ +# "How does the proposed solution address the problem?", +# "What are the implications of the findings?", +# "Can you explain the theory behind the observed phenomenon?" +# ] +# +# all_prompts2 = prompts_category_1 + prompts_category_2 + + + +####################################################################################################################### +# +# Migration Script +import sqlite3 +import uuid +import logging +import os +from datetime import datetime +import shutil + +# def migrate_media_db_to_rag_chat_db(media_db_path, rag_chat_db_path): +# # Check if migration is needed +# if not os.path.exists(media_db_path): +# logging.info("Media DB does not exist. No migration needed.") +# return +# +# # Optional: Check if migration has already been completed +# migration_flag = os.path.join(os.path.dirname(rag_chat_db_path), 'migration_completed.flag') +# if os.path.exists(migration_flag): +# logging.info("Migration already completed. 
Skipping migration.") +# return +# +# # Backup databases +# backup_database(media_db_path) +# backup_database(rag_chat_db_path) +# +# # Connect to both databases +# try: +# media_conn = sqlite3.connect(media_db_path) +# rag_conn = sqlite3.connect(rag_chat_db_path) +# +# # Enable foreign key support +# media_conn.execute('PRAGMA foreign_keys = ON;') +# rag_conn.execute('PRAGMA foreign_keys = ON;') +# +# media_cursor = media_conn.cursor() +# rag_cursor = rag_conn.cursor() +# +# # Begin transaction +# rag_conn.execute('BEGIN TRANSACTION;') +# +# # Extract conversations from media DB +# media_cursor.execute(''' +# SELECT id, media_id, media_name, conversation_name, created_at, updated_at +# FROM ChatConversations +# ''') +# conversations = media_cursor.fetchall() +# +# for conv in conversations: +# old_conv_id, media_id, media_name, conversation_name, created_at, updated_at = conv +# +# # Convert timestamps if necessary +# created_at = parse_timestamp(created_at) +# updated_at = parse_timestamp(updated_at) +# +# # Generate a new conversation_id +# conversation_id = str(uuid.uuid4()) +# title = conversation_name or (f"{media_name}-chat" if media_name else "Untitled Conversation") +# +# # Insert into conversation_metadata +# rag_cursor.execute(''' +# INSERT INTO conversation_metadata (conversation_id, created_at, last_updated, title, media_id) +# VALUES (?, ?, ?, ?, ?) +# ''', (conversation_id, created_at, updated_at, title, media_id)) +# +# # Extract messages from media DB +# media_cursor.execute(''' +# SELECT sender, message, timestamp +# FROM ChatMessages +# WHERE conversation_id = ? 
#             ORDER BY timestamp ASC
#             ''', (old_conv_id,))
#             messages = media_cursor.fetchall()
#
#             for msg in messages:
#                 sender, content, timestamp = msg
#
#                 # Convert timestamp if necessary
#                 timestamp = parse_timestamp(timestamp)
#
#                 role = sender  # Assuming 'sender' is 'user' or 'ai'
#
#                 # Insert message into rag_qa_chats
#                 rag_cursor.execute('''
#                     INSERT INTO rag_qa_chats (conversation_id, timestamp, role, content)
#                     VALUES (?, ?, ?, ?)
#                 ''', (conversation_id, timestamp, role, content))
#
#         # Commit transaction
#         rag_conn.commit()
#         logging.info("Migration completed successfully.")
#
#         # Mark migration as complete
#         with open(migration_flag, 'w') as f:
#             f.write('Migration completed on ' + datetime.now().isoformat())
#
#     except Exception as e:
#         # Rollback transaction in case of error
#         rag_conn.rollback()
#         logging.error(f"Error during migration: {e}")
#         raise
#     finally:
#         media_conn.close()
#         rag_conn.close()

def backup_database(db_path):
    """
    Copy a database file to ``<db_path>.backup`` unless that backup already exists.

    An existing backup is never overwritten, so the ``.backup`` file always
    reflects the first copy that was taken.

    Args:
        db_path: Path of the database file to copy.
    """
    backup_path = db_path + '.backup'
    if not os.path.exists(backup_path):
        shutil.copyfile(db_path, backup_path)
        logging.info(f"Database backed up to {backup_path}")
    else:
        logging.info(f"Backup already exists at {backup_path}")

def parse_timestamp(timestamp_value):
    """
    Parses the timestamp from the old database and converts it to a standard format.
    Adjust this function based on the actual format of your timestamps.

    Parsing is attempted in this order:
      1. ISO-8601 text via ``datetime.fromisoformat``.
      2. Unix epoch seconds (a number or a numeric string) via ``datetime.fromtimestamp``.
      3. Fallback: log a warning and use the current local time.

    Args:
        timestamp_value: Raw timestamp value; may be a str, a number, or None.

    Returns:
        str: The timestamp rendered with ``datetime.isoformat()``.
    """
    try:
        # Attempt to parse ISO format
        return datetime.fromisoformat(timestamp_value).isoformat()
    except (TypeError, ValueError):
        # ValueError: a string that is not ISO formatted.
        # TypeError: not a string at all (None, int, float) - fromisoformat only accepts str.
        pass

    # Handle other timestamp formats if necessary
    # For example, if timestamps are in Unix epoch format
    try:
        timestamp_float = float(timestamp_value)
        return datetime.fromtimestamp(timestamp_float).isoformat()
    except (TypeError, ValueError, OverflowError, OSError):
        # float() rejects None / non-numeric text; fromtimestamp() rejects NaN and
        # values outside the range supported by the platform.
        # Default to current time if parsing fails
        logging.warning(f"Unable to parse timestamp '{timestamp_value}', using current time.")
        return datetime.now().isoformat()

#
# End of Migration Script
#######################################################################################################################


#######################################################################################################################
#
# Launch UI Function
def launch_ui(share_public=None, server_mode=False):
    """
    Build the complete Gradio tab layout and launch the web interface.

    Steps performed, in order:
      1. Open a browser tab at ``http://127.0.0.1:7860/?__theme=dark``.
      2. Load the configuration and derive the chat / RAG database paths, which
         are placed in the same directory as the media database.
      3. Call ``create_tables()`` to initialise the RAG chat database.
      4. Assemble every tab group inside a single ``gr.Blocks`` app.
      5. Launch the app (see the branch comments at the bottom).

    Args:
        share_public: When exactly ``True``, launch with ``share=True``.
        server_mode: When truthy (and ``share_public`` is falsy), launch bound to
            ``0.0.0.0`` without catching launch errors.
    """
    # NOTE(review): the tab is opened before the server is started and in every
    # launch mode, with the port hard-coded here rather than taken from
    # server_port_variable below - confirm this is intended.
    webbrowser.open_new_tab('http://127.0.0.1:7860/?__theme=dark')
    share=share_public
    css = """
    .result-box {
        margin-bottom: 20px;
        border: 1px solid #ddd;
        padding: 10px;
    }
    .result-box.error {
        border-color: #ff0000;
        background-color: #ffeeee;
    }
    .transcription, .summary {
        max-height: 800px;
        overflow-y: auto;
        border: 1px solid #eee;
        padding: 10px;
        margin-top: 10px;
    }
    """

    config = load_and_log_configs()
    # Get database paths from config
    db_config = config['db_config']
    media_db_path = db_config['sqlite_path']
    character_chat_db_path = os.path.join(os.path.dirname(media_db_path), "chatDB.db")
    rag_chat_db_path = os.path.join(os.path.dirname(media_db_path), "rag_qa.db")
    # Initialize the RAG Chat DB (create tables and update schema)
    create_tables()

    # Migrate data from the media DB to the RAG Chat DB
    #migrate_media_db_to_rag_chat_db(media_db_path, rag_chat_db_path)


    with gr.Blocks(theme='bethecloud/storj_theme',css=css) as iface:
        # NOTE(review): the HTML payload is whitespace-only in this view of the file.
        gr.HTML(
            """
            
            """
        )
        # Rebinds db_config: from here on it is the result of get_db_config(),
        # not the config['db_config'] mapping read above.
        db_config = get_db_config()
        db_type = db_config['type']
        gr.Markdown(f"# tl/dw: Your LLM-powered Research Multi-tool")
        gr.Markdown(f"(Using {db_type.capitalize()} Database)")
        with gr.Tabs():
            with gr.TabItem("Transcribe / Analyze / Ingestion", id="ingestion-grouping", visible=True):
                with gr.Tabs():
                    create_video_transcription_tab()
                    create_audio_processing_tab()
                    create_podcast_tab()
                    create_import_book_tab()
                    create_plain_text_import_tab()
                    create_xml_import_tab()
                    create_website_scraping_tab()
                    create_pdf_ingestion_tab()
                    create_pdf_ingestion_test_tab()
                    create_resummary_tab()
                    create_summarize_explain_tab()
                    create_live_recording_tab()
                    create_arxiv_tab()
                    create_semantic_scholar_tab()

            with gr.TabItem("RAG Chat/Search", id="RAG Chat Notes group", visible=True):
                create_rag_tab()
                create_rag_qa_chat_tab()
                create_rag_qa_notes_management_tab()
                create_rag_qa_chat_management_tab()

            with gr.TabItem("Chat with an LLM", id="LLM Chat group", visible=True):
                create_chat_interface()
                create_chat_interface_stacked()
                create_chat_interface_multi_api()
                create_chat_interface_four()
                chat_workflows_tab()

            with gr.TabItem("Character Chat", id="character chat group", visible=True):
                create_character_card_interaction_tab()
                create_character_chat_mgmt_tab()
                create_custom_character_card_tab()
                create_character_card_validation_tab()
                create_multiple_character_chat_tab()
                create_narrator_controlled_conversation_tab()
                create_export_characters_tab()

            with gr.TabItem("Writing Tools", id="writing_tools group", visible=True):
                from App_Function_Libraries.Gradio_UI.Writing_tab import create_document_feedback_tab
                create_document_feedback_tab()
                from App_Function_Libraries.Gradio_UI.Writing_tab import create_grammar_style_check_tab
                create_grammar_style_check_tab()
                from App_Function_Libraries.Gradio_UI.Writing_tab import create_tone_adjustment_tab
                create_tone_adjustment_tab()
                from App_Function_Libraries.Gradio_UI.Writing_tab import create_creative_writing_tab
                create_creative_writing_tab()
                from App_Function_Libraries.Gradio_UI.Writing_tab import create_mikupad_tab
                create_mikupad_tab()

            with gr.TabItem("Search/View DB Items", id="view db items group", visible=True):
                create_search_tab()
                create_search_summaries_tab()
                create_view_all_mediadb_with_versions_tab()
                create_viewing_mediadb_tab()
                create_mediadb_keyword_search_tab()
                create_view_all_rag_notes_tab()
                create_viewing_ragdb_tab()
                create_ragdb_keyword_items_tab()

            with gr.TabItem("Prompts", id='view prompts group', visible=True):
                with gr.Tabs():
                    create_prompt_view_tab()
                    create_prompt_search_tab()
                    create_prompt_edit_tab()
                    create_prompt_clone_tab()
                    create_prompt_suggestion_tab()
                    create_prompts_export_tab()

            with gr.TabItem("Manage Media DB Items", id="manage group", visible=True):
                create_media_edit_tab()
                create_manage_items_tab()
                create_media_edit_and_clone_tab()

            with gr.TabItem("Embeddings Management", id="embeddings group", visible=True):
                create_embeddings_tab()
                create_view_embeddings_tab()
                create_purge_embeddings_tab()

            with gr.TabItem("Keywords", id="keywords group", visible=True):
                create_view_keywords_tab()
                create_add_keyword_tab()
                create_delete_keyword_tab()
                create_export_keywords_tab()
                create_character_keywords_tab()
                create_rag_qa_keywords_tab()
                create_meta_keywords_tab()
                create_prompt_keywords_tab()

            with gr.TabItem("Import", id="import group", visible=True):
                create_import_item_tab()
                create_import_obsidian_vault_tab()
                create_import_single_prompt_tab()
                create_import_multiple_prompts_tab()
                create_mediawiki_import_tab()
                create_mediawiki_config_tab()
                create_conversation_import_tab()

            with gr.TabItem("Export", id="export group", visible=True):
                create_export_tabs()


            with gr.TabItem("Database Management", id="database_management_group", visible=True):
                create_database_management_interface(
                    media_db_config={
                        'db_path': media_db_path,
                        'backup_dir': backup_dir
                    },
                    rag_db_config={
                        'db_path': rag_chat_db_path,
                        'backup_dir': backup_dir
                    },
                    char_db_config={
                        'db_path': character_chat_db_path,
                        'backup_dir': backup_dir
                    }
                )

            with gr.TabItem("Utilities", id="util group", visible=True):
                create_mindmap_tab()
                create_utilities_yt_video_tab()
                create_utilities_yt_audio_tab()
                create_utilities_yt_timestamp_tab()

            with gr.TabItem("Anki Deck Creation/Validation", id="anki group", visible=True):
                create_anki_generator_tab()
                create_anki_validation_tab()

            with gr.TabItem("Local LLM", id="local llm group", visible=True):
                create_chat_with_llamafile_tab()
                create_ollama_tab()
                #create_huggingface_tab()

            with gr.TabItem("Trashcan", id="trashcan group", visible=True):
                create_search_and_mark_trash_tab()
                create_view_trash_tab()
                create_delete_trash_tab()
                create_empty_trash_tab()

            with gr.TabItem("Evaluations", id="eval", visible=True):
                create_geval_tab()
                create_infinite_bench_tab()
                # FIXME
                #create_mmlu_pro_tab()

            with gr.TabItem("Introduction/Help", id="introduction group", visible=True):
                create_introduction_tab()

            with gr.TabItem("Config Editor", id="config group"):
                create_config_editor_tab()

    # Launch the interface
    server_port_variable = 7860
    os.environ['GRADIO_ANALYTICS_ENABLED'] = 'False'
    if share==True:
        iface.launch(share=True)
    elif server_mode and not share_public:
        # Same arguments as the fallback branch below, but launch errors propagate.
        iface.launch(share=False, server_name="0.0.0.0", server_port=server_port_variable, )
    else:
        try:
            iface.launch(share=False, server_name="0.0.0.0", server_port=server_port_variable, )
        except Exception as e:
            # Launch failures are logged instead of raised in this mode.
            logging.error(f"Error launching interface: {str(e)}")

# diff --git a/App_Function_Libraries/Gradio_UI/Anki_tab.py b/App_Function_Libraries/Gradio_UI/Anki_tab.py new file mode 100644 index 0000000000000000000000000000000000000000..2c2490266af71fc957e9e1709746c184e104ac1a --- /dev/null +++ b/App_Function_Libraries/Gradio_UI/Anki_tab.py @@ -0,0 +1,921 @@
# Anki_Validation_tab.py
# Description: Gradio functions for the Anki
Validation tab +# +# Imports +import json +import logging +import os +import tempfile +from typing import Optional, Tuple, List, Dict +# +# External Imports +import genanki +import gradio as gr +# +# Local Imports +from App_Function_Libraries.Chat.Chat_Functions import approximate_token_count, update_chat_content, save_chat_history, \ + save_chat_history_to_db_wrapper +from App_Function_Libraries.DB.DB_Manager import list_prompts +from App_Function_Libraries.Gradio_UI.Chat_ui import update_dropdown_multiple, chat_wrapper, update_selected_parts, \ + search_conversations, regenerate_last_message, load_conversation, debug_output +from App_Function_Libraries.Third_Party.Anki import sanitize_html, generate_card_choices, \ + export_cards, load_card_for_editing, handle_file_upload, \ + validate_for_ui, update_card_with_validation, update_card_choices, enhanced_file_upload, \ + handle_validation +from App_Function_Libraries.Utils.Utils import default_api_endpoint, global_api_endpoints, format_api_name +# +############################################################################################################ +# +# Functions: + +def create_anki_validation_tab(): + with gr.TabItem("Anki Flashcard Validation", visible=True): + gr.Markdown("# Anki Flashcard Validation and Editor") + + # State variables for internal tracking + current_card_data = gr.State({}) + preview_update_flag = gr.State(False) + + with gr.Row(): + # Left Column: Input and Validation + with gr.Column(scale=1): + gr.Markdown("## Import or Create Flashcards") + + input_type = gr.Radio( + choices=["JSON", "APKG"], + label="Input Type", + value="JSON" + ) + + with gr.Group() as json_input_group: + flashcard_input = gr.TextArea( + label="Enter Flashcards (JSON format)", + placeholder='''{ + "cards": [ + { + "id": "CARD_001", + "type": "basic", + "front": "What is the capital of France?", + "back": "Paris", + "tags": ["geography", "europe"], + "note": "Remember: City of Light" + } + ] +}''', + lines=10 + ) + 
+ import_json = gr.File( + label="Or Import JSON File", + file_types=[".json"] + ) + + with gr.Group(visible=False) as apkg_input_group: + import_apkg = gr.File( + label="Import APKG File", + file_types=[".apkg"] + ) + deck_info = gr.JSON( + label="Deck Information", + visible=False + ) + + validate_button = gr.Button("Validate Flashcards") + + # Right Column: Validation Results and Editor + with gr.Column(scale=1): + gr.Markdown("## Validation Results") + validation_status = gr.Markdown("") + + with gr.Accordion("Validation Rules", open=False): + gr.Markdown(""" + ### Required Fields: + - Unique ID + - Card Type (basic, cloze, reverse) + - Front content + - Back content + - At least one tag + + ### Content Rules: + - No empty fields + - Front side should be a clear question/prompt + - Back side should contain complete answer + - Cloze deletions must have valid syntax + - No duplicate IDs + + ### Image Rules: + - Valid image tags + - Supported formats (JPG, PNG, GIF) + - Base64 encoded or valid URL + + ### APKG-specific Rules: + - Valid SQLite database structure + - Media files properly referenced + - Note types match Anki standards + - Card templates are well-formed + """) + + with gr.Row(): + # Card Editor + gr.Markdown("## Card Editor") + with gr.Row(): + with gr.Column(scale=1): + with gr.Accordion("Edit Individual Cards", open=True): + card_selector = gr.Dropdown( + label="Select Card to Edit", + choices=[], + interactive=True + ) + + card_type = gr.Radio( + choices=["basic", "cloze", "reverse"], + label="Card Type", + value="basic" + ) + + # Front content with preview + with gr.Group(): + gr.Markdown("### Front Content") + front_content = gr.TextArea( + label="Content (HTML supported)", + lines=3 + ) + front_preview = gr.HTML( + label="Preview" + ) + + # Back content with preview + with gr.Group(): + gr.Markdown("### Back Content") + back_content = gr.TextArea( + label="Content (HTML supported)", + lines=3 + ) + back_preview = gr.HTML( + label="Preview" + ) + 
+ tags_input = gr.TextArea( + label="Tags (comma-separated)", + lines=1 + ) + + notes_input = gr.TextArea( + label="Additional Notes", + lines=2 + ) + + with gr.Row(): + update_card_button = gr.Button("Update Card") + delete_card_button = gr.Button("Delete Card", variant="stop") + + with gr.Row(): + with gr.Column(scale=1): + # Export Options + gr.Markdown("## Export Options") + export_format = gr.Radio( + choices=["Anki CSV", "JSON", "Plain Text"], + label="Export Format", + value="Anki CSV" + ) + export_button = gr.Button("Export Valid Cards") + export_file = gr.File(label="Download Validated Cards") + export_status = gr.Markdown("") + with gr.Column(scale=1): + gr.Markdown("## Export Instructions") + gr.Markdown(""" + ### Anki CSV Format: + - Front, Back, Tags, Type, Note + - Use for importing into Anki + - Images preserved as HTML + + ### JSON Format: + - JSON array of cards + - Images as base64 or URLs + - Use for custom processing + + ### Plain Text Format: + - Question and Answer pairs + - Images represented as [IMG] placeholder + - Use for manual review + """) + + def update_preview(content): + """Update preview with sanitized content.""" + if not content: + return "" + return sanitize_html(content) + + # Event handlers + def validation_chain(content: str) -> Tuple[str, List[str]]: + """Combined validation and card choice update.""" + validation_message = validate_for_ui(content) + card_choices = update_card_choices(content) + return validation_message, card_choices + + def delete_card(card_selection, current_content): + """Delete selected card and return updated content.""" + if not card_selection or not current_content: + return current_content, "No card selected", [] + + try: + data = json.loads(current_content) + selected_id = card_selection.split(" - ")[0] + + data['cards'] = [card for card in data['cards'] if card['id'] != selected_id] + new_content = json.dumps(data, indent=2) + + return ( + new_content, + "Card deleted successfully!", + 
generate_card_choices(new_content) + ) + + except Exception as e: + return current_content, f"Error deleting card: {str(e)}", [] + + def process_validation_result(is_valid, message): + """Process validation result into a formatted markdown string.""" + if is_valid: + return f"✅ {message}" + else: + return f"❌ {message}" + + # Register event handlers + input_type.change( + fn=lambda t: ( + gr.update(visible=t == "JSON"), + gr.update(visible=t == "APKG"), + gr.update(visible=t == "APKG") + ), + inputs=[input_type], + outputs=[json_input_group, apkg_input_group, deck_info] + ) + + # File upload handlers + import_json.upload( + fn=handle_file_upload, + inputs=[import_json, input_type], + outputs=[ + flashcard_input, + deck_info, + validation_status, + card_selector + ] + ) + + import_apkg.upload( + fn=enhanced_file_upload, + inputs=[import_apkg, input_type], + outputs=[ + flashcard_input, + deck_info, + validation_status, + card_selector + ] + ) + + # Validation handler + validate_button.click( + fn=lambda content, input_format: ( + handle_validation(content, input_format), + generate_card_choices(content) if content else [] + ), + inputs=[flashcard_input, input_type], + outputs=[validation_status, card_selector] + ) + + # Card editing handlers + # Card selector change event + card_selector.change( + fn=load_card_for_editing, + inputs=[card_selector, flashcard_input], + outputs=[ + card_type, + front_content, + back_content, + tags_input, + notes_input, + front_preview, + back_preview + ] + ) + + # Live preview updates + front_content.change( + fn=update_preview, + inputs=[front_content], + outputs=[front_preview] + ) + + back_content.change( + fn=update_preview, + inputs=[back_content], + outputs=[back_preview] + ) + + # Card update handler + update_card_button.click( + fn=update_card_with_validation, + inputs=[ + flashcard_input, + card_selector, + card_type, + front_content, + back_content, + tags_input, + notes_input + ], + outputs=[ + flashcard_input, + 
validation_status, + card_selector + ] + ) + + # Delete card handler + delete_card_button.click( + fn=delete_card, + inputs=[card_selector, flashcard_input], + outputs=[flashcard_input, validation_status, card_selector] + ) + + # Export handler + export_button.click( + fn=export_cards, + inputs=[flashcard_input, export_format], + outputs=[export_status, export_file] + ) + + return ( + flashcard_input, + import_json, + import_apkg, + validate_button, + validation_status, + card_selector, + card_type, + front_content, + back_content, + front_preview, + back_preview, + tags_input, + notes_input, + update_card_button, + delete_card_button, + export_format, + export_button, + export_file, + export_status, + deck_info + ) + + +def create_anki_generator_tab(): + with gr.TabItem("Anki Deck Generator", visible=True): + try: + default_value = None + if default_api_endpoint: + if default_api_endpoint in global_api_endpoints: + default_value = format_api_name(default_api_endpoint) + else: + logging.warning(f"Default API endpoint '{default_api_endpoint}' not found in global_api_endpoints") + except Exception as e: + logging.error(f"Error setting default API endpoint: {str(e)}") + default_value = None + custom_css = """ + .chatbot-container .message-wrap .message { + font-size: 14px !important; + } + """ + with gr.TabItem("LLM Chat & Anki Deck Creation", visible=True): + gr.Markdown("# Chat with an LLM to help you come up with Questions/Answers for an Anki Deck") + chat_history = gr.State([]) + media_content = gr.State({}) + selected_parts = gr.State([]) + conversation_id = gr.State(None) + initial_prompts, total_pages, current_page = list_prompts(page=1, per_page=10) + + with gr.Row(): + with gr.Column(scale=1): + search_query_input = gr.Textbox( + label="Search Query", + placeholder="Enter your search query here..." 
+ ) + search_type_input = gr.Radio( + choices=["Title", "Content", "Author", "Keyword"], + value="Keyword", + label="Search By" + ) + keyword_filter_input = gr.Textbox( + label="Filter by Keywords (comma-separated)", + placeholder="ml, ai, python, etc..." + ) + search_button = gr.Button("Search") + items_output = gr.Dropdown(label="Select Item", choices=[], interactive=True) + item_mapping = gr.State({}) + with gr.Row(): + use_content = gr.Checkbox(label="Use Content") + use_summary = gr.Checkbox(label="Use Summary") + use_prompt = gr.Checkbox(label="Use Prompt") + save_conversation = gr.Checkbox(label="Save Conversation", value=False, visible=True) + with gr.Row(): + temperature = gr.Slider(label="Temperature", minimum=0.00, maximum=1.0, step=0.05, value=0.7) + with gr.Row(): + conversation_search = gr.Textbox(label="Search Conversations") + with gr.Row(): + search_conversations_btn = gr.Button("Search Conversations") + with gr.Row(): + previous_conversations = gr.Dropdown(label="Select Conversation", choices=[], interactive=True) + with gr.Row(): + load_conversations_btn = gr.Button("Load Selected Conversation") + + # Refactored API selection dropdown + api_endpoint = gr.Dropdown( + choices=["None"] + [format_api_name(api) for api in global_api_endpoints], + value=default_value, + label="API for Chat Interaction (Optional)" + ) + api_key = gr.Textbox(label="API Key (if required)", type="password") + custom_prompt_checkbox = gr.Checkbox(label="Use a Custom Prompt", + value=False, + visible=True) + preset_prompt_checkbox = gr.Checkbox(label="Use a Pre-set Prompt", + value=False, + visible=True) + with gr.Row(visible=False) as preset_prompt_controls: + prev_prompt_page = gr.Button("Previous") + next_prompt_page = gr.Button("Next") + current_prompt_page_text = gr.Text(f"Page {current_page} of {total_pages}") + current_prompt_page_state = gr.State(value=1) + + preset_prompt = gr.Dropdown( + label="Select Preset Prompt", + choices=initial_prompts + ) + user_prompt = 
gr.Textbox(label="Custom Prompt", + placeholder="Enter custom prompt here", + lines=3, + visible=False) + system_prompt_input = gr.Textbox(label="System Prompt", + value="You are a helpful AI assitant", + lines=3, + visible=False) + with gr.Column(scale=2): + chatbot = gr.Chatbot(height=800, elem_classes="chatbot-container") + msg = gr.Textbox(label="Enter your message") + submit = gr.Button("Submit") + regenerate_button = gr.Button("Regenerate Last Message") + token_count_display = gr.Number(label="Approximate Token Count", value=0, interactive=False) + clear_chat_button = gr.Button("Clear Chat") + + chat_media_name = gr.Textbox(label="Custom Chat Name(optional)") + save_chat_history_to_db = gr.Button("Save Chat History to DataBase") + save_status = gr.Textbox(label="Save Status", interactive=False) + save_chat_history_as_file = gr.Button("Save Chat History as File") + download_file = gr.File(label="Download Chat History") + + search_button.click( + fn=update_dropdown_multiple, + inputs=[search_query_input, search_type_input, keyword_filter_input], + outputs=[items_output, item_mapping] + ) + + def update_prompt_visibility(custom_prompt_checked, preset_prompt_checked): + user_prompt_visible = custom_prompt_checked + system_prompt_visible = custom_prompt_checked + preset_prompt_visible = preset_prompt_checked + preset_prompt_controls_visible = preset_prompt_checked + return ( + gr.update(visible=user_prompt_visible, interactive=user_prompt_visible), + gr.update(visible=system_prompt_visible, interactive=system_prompt_visible), + gr.update(visible=preset_prompt_visible, interactive=preset_prompt_visible), + gr.update(visible=preset_prompt_controls_visible) + ) + + def update_prompt_page(direction, current_page_val): + new_page = current_page_val + direction + if new_page < 1: + new_page = 1 + prompts, total_pages, _ = list_prompts(page=new_page, per_page=20) + if new_page > total_pages: + new_page = total_pages + prompts, total_pages, _ = list_prompts(page=new_page, 
per_page=20) + return ( + gr.update(choices=prompts), + gr.update(value=f"Page {new_page} of {total_pages}"), + new_page + ) + + def clear_chat(): + return [], None # Return empty list for chatbot and None for conversation_id + + custom_prompt_checkbox.change( + update_prompt_visibility, + inputs=[custom_prompt_checkbox, preset_prompt_checkbox], + outputs=[user_prompt, system_prompt_input, preset_prompt, preset_prompt_controls] + ) + + preset_prompt_checkbox.change( + update_prompt_visibility, + inputs=[custom_prompt_checkbox, preset_prompt_checkbox], + outputs=[user_prompt, system_prompt_input, preset_prompt, preset_prompt_controls] + ) + + prev_prompt_page.click( + lambda x: update_prompt_page(-1, x), + inputs=[current_prompt_page_state], + outputs=[preset_prompt, current_prompt_page_text, current_prompt_page_state] + ) + + next_prompt_page.click( + lambda x: update_prompt_page(1, x), + inputs=[current_prompt_page_state], + outputs=[preset_prompt, current_prompt_page_text, current_prompt_page_state] + ) + + submit.click( + chat_wrapper, + inputs=[msg, chatbot, media_content, selected_parts, api_endpoint, api_key, user_prompt, + conversation_id, + save_conversation, temperature, system_prompt_input], + outputs=[msg, chatbot, conversation_id] + ).then( # Clear the message box after submission + lambda x: gr.update(value=""), + inputs=[chatbot], + outputs=[msg] + ).then( # Clear the user prompt after the first message + lambda: (gr.update(value=""), gr.update(value="")), + outputs=[user_prompt, system_prompt_input] + ).then( + lambda history: approximate_token_count(history), + inputs=[chatbot], + outputs=[token_count_display] + ) + + + clear_chat_button.click( + clear_chat, + outputs=[chatbot, conversation_id] + ) + + items_output.change( + update_chat_content, + inputs=[items_output, use_content, use_summary, use_prompt, item_mapping], + outputs=[media_content, selected_parts] + ) + + use_content.change(update_selected_parts, inputs=[use_content, use_summary, 
use_prompt], + outputs=[selected_parts]) + use_summary.change(update_selected_parts, inputs=[use_content, use_summary, use_prompt], + outputs=[selected_parts]) + use_prompt.change(update_selected_parts, inputs=[use_content, use_summary, use_prompt], + outputs=[selected_parts]) + items_output.change(debug_output, inputs=[media_content, selected_parts], outputs=[]) + + search_conversations_btn.click( + search_conversations, + inputs=[conversation_search], + outputs=[previous_conversations] + ) + + load_conversations_btn.click( + clear_chat, + outputs=[chatbot, chat_history] + ).then( + load_conversation, + inputs=[previous_conversations], + outputs=[chatbot, conversation_id] + ) + + previous_conversations.change( + load_conversation, + inputs=[previous_conversations], + outputs=[chat_history] + ) + + save_chat_history_as_file.click( + save_chat_history, + inputs=[chatbot, conversation_id], + outputs=[download_file] + ) + + save_chat_history_to_db.click( + save_chat_history_to_db_wrapper, + inputs=[chatbot, conversation_id, media_content, chat_media_name], + outputs=[conversation_id, gr.Textbox(label="Save Status")] + ) + + regenerate_button.click( + regenerate_last_message, + inputs=[chatbot, media_content, selected_parts, api_endpoint, api_key, user_prompt, temperature, + system_prompt_input], + outputs=[chatbot, save_status] + ).then( + lambda history: approximate_token_count(history), + inputs=[chatbot], + outputs=[token_count_display] + ) + gr.Markdown("# Create Anki Deck") + + with gr.Row(): + # Left Column: Deck Settings + with gr.Column(scale=1): + gr.Markdown("## Deck Settings") + deck_name = gr.Textbox( + label="Deck Name", + placeholder="My Study Deck", + value="My Study Deck" + ) + + deck_description = gr.Textbox( + label="Deck Description", + placeholder="Description of your deck", + lines=2 + ) + + note_type = gr.Radio( + choices=["Basic", "Basic (and reversed)", "Cloze"], + label="Note Type", + value="Basic" + ) + + # Card Fields based on note type + 
with gr.Group() as basic_fields: + front_template = gr.Textbox( + label="Front Template (HTML)", + value="{{Front}}", + lines=3 + ) + back_template = gr.Textbox( + label="Back Template (HTML)", + value="{{FrontSide}}
{{Back}}", + lines=3 + ) + + with gr.Group() as cloze_fields: + cloze_template = gr.Textbox( + label="Cloze Template (HTML)", + value="{{cloze:Text}}", + lines=3, + visible=False + ) + + css_styling = gr.Textbox( + label="Card Styling (CSS)", + value=".card {\n font-family: arial;\n font-size: 20px;\n text-align: center;\n color: black;\n background-color: white;\n}\n\n.cloze {\n font-weight: bold;\n color: blue;\n}", + lines=5 + ) + + # Right Column: Card Creation + with gr.Column(scale=1): + gr.Markdown("## Add Cards") + + with gr.Group() as basic_input: + front_content = gr.TextArea( + label="Front Content", + placeholder="Question or prompt", + lines=3 + ) + back_content = gr.TextArea( + label="Back Content", + placeholder="Answer", + lines=3 + ) + + with gr.Group() as cloze_input: + cloze_content = gr.TextArea( + label="Cloze Content", + placeholder="Text with {{c1::cloze}} deletions", + lines=3, + visible=False + ) + + tags_input = gr.TextArea( + label="Tags (comma-separated)", + placeholder="tag1, tag2, tag3", + lines=1 + ) + + add_card_btn = gr.Button("Add Card") + + cards_list = gr.JSON( + label="Cards in Deck", + value={"cards": []} + ) + + clear_cards_btn = gr.Button("Clear All Cards", variant="stop") + + with gr.Row(): + generate_deck_btn = gr.Button("Generate Deck", variant="primary") + download_deck = gr.File(label="Download Deck") + generation_status = gr.Markdown("") + + def update_note_type_fields(note_type: str): + if note_type == "Cloze": + return { + basic_input: gr.update(visible=False), + cloze_input: gr.update(visible=True), + basic_fields: gr.update(visible=False), + cloze_fields: gr.update(visible=True) + } + else: + return { + basic_input: gr.update(visible=True), + cloze_input: gr.update(visible=False), + basic_fields: gr.update(visible=True), + cloze_fields: gr.update(visible=False) + } + + def add_card(note_type: str, front: str, back: str, cloze: str, tags: str, current_cards: Dict[str, List]): + if not current_cards: + current_cards = 
{"cards": []} + + cards_data = current_cards["cards"] + + # Process tags + card_tags = [tag.strip() for tag in tags.split(',') if tag.strip()] + + new_card = { + "id": f"CARD_{len(cards_data) + 1}", + "tags": card_tags + } + + if note_type == "Cloze": + if not cloze or "{{c" not in cloze: + return current_cards, "❌ Invalid cloze format. Use {{c1::text}} syntax." + new_card.update({ + "type": "cloze", + "content": cloze + }) + else: + if not front or not back: + return current_cards, "❌ Both front and back content are required." + new_card.update({ + "type": "basic", + "front": front, + "back": back, + "is_reverse": note_type == "Basic (and reversed)" + }) + + cards_data.append(new_card) + return {"cards": cards_data}, "✅ Card added successfully!" + + def clear_cards() -> Tuple[Dict[str, List], str]: + return {"cards": []}, "✅ All cards cleared!" + + def generate_anki_deck( + deck_name: str, + deck_description: str, + note_type: str, + front_template: str, + back_template: str, + cloze_template: str, + css: str, + cards_data: Dict[str, List] + ) -> Tuple[Optional[str], str]: + try: + if not cards_data or not cards_data.get("cards"): + return None, "❌ No cards to generate deck from!" + + # Create model based on note type + if note_type == "Cloze": + model = genanki.Model( + 1483883320, # Random model ID + 'Cloze Model', + fields=[ + {'name': 'Text'}, + {'name': 'Back Extra'} + ], + templates=[{ + 'name': 'Cloze Card', + 'qfmt': cloze_template, + 'afmt': cloze_template + '

{{Back Extra}}' + }], + css=css, + # FIXME CLOZE DOESNT EXIST + model_type=1 + ) + else: + templates = [{ + 'name': 'Card 1', + 'qfmt': front_template, + 'afmt': back_template + }] + + if note_type == "Basic (and reversed)": + templates.append({ + 'name': 'Card 2', + 'qfmt': '{{Back}}', + 'afmt': '{{FrontSide}}
{{Front}}' + }) + + model = genanki.Model( + 1607392319, # Random model ID + 'Basic Model', + fields=[ + {'name': 'Front'}, + {'name': 'Back'} + ], + templates=templates, + css=css + ) + + # Create deck + deck = genanki.Deck( + 2059400110, # Random deck ID + deck_name, + description=deck_description + ) + + # Add cards to deck + for card in cards_data["cards"]: + if card["type"] == "cloze": + note = genanki.Note( + model=model, + fields=[card["content"], ""], + tags=card["tags"] + ) + else: + note = genanki.Note( + model=model, + fields=[card["front"], card["back"]], + tags=card["tags"] + ) + deck.add_note(note) + + # Save deck to temporary file + temp_dir = tempfile.mkdtemp() + deck_path = os.path.join(temp_dir, f"{deck_name}.apkg") + genanki.Package(deck).write_to_file(deck_path) + + return deck_path, "✅ Deck generated successfully!" + + except Exception as e: + return None, f"❌ Error generating deck: {str(e)}" + + # Register event handlers + note_type.change( + fn=update_note_type_fields, + inputs=[note_type], + outputs=[basic_input, cloze_input, basic_fields, cloze_fields] + ) + + add_card_btn.click( + fn=add_card, + inputs=[ + note_type, + front_content, + back_content, + cloze_content, + tags_input, + cards_list + ], + outputs=[cards_list, generation_status] + ) + + clear_cards_btn.click( + fn=clear_cards, + inputs=[], + outputs=[cards_list, generation_status] + ) + + generate_deck_btn.click( + fn=generate_anki_deck, + inputs=[ + deck_name, + deck_description, + note_type, + front_template, + back_template, + cloze_template, + css_styling, + cards_list + ], + outputs=[download_deck, generation_status] + ) + + + return ( + deck_name, + deck_description, + note_type, + front_template, + back_template, + cloze_template, + css_styling, + front_content, + back_content, + cloze_content, + tags_input, + cards_list, + add_card_btn, + clear_cards_btn, + generate_deck_btn, + download_deck, + generation_status + ) + +# +# End of Anki_Validation_tab.py 
+############################################################################################################ diff --git a/App_Function_Libraries/Gradio_UI/Audio_ingestion_tab.py b/App_Function_Libraries/Gradio_UI/Audio_ingestion_tab.py index 3eee842b05c12075fcb23f9ee7b623c6f768604a..892aa622f3a877074a68f46e91a743dd66c8737f 100644 --- a/App_Function_Libraries/Gradio_UI/Audio_ingestion_tab.py +++ b/App_Function_Libraries/Gradio_UI/Audio_ingestion_tab.py @@ -2,16 +2,18 @@ # Description: Gradio UI for ingesting audio files into the database # # Imports +import logging # # External Imports import gradio as gr # # Local Imports from App_Function_Libraries.Audio.Audio_Files import process_audio_files -from App_Function_Libraries.DB.DB_Manager import load_preset_prompts +from App_Function_Libraries.DB.DB_Manager import list_prompts from App_Function_Libraries.Gradio_UI.Chat_ui import update_user_prompt from App_Function_Libraries.Gradio_UI.Gradio_Shared import whisper_models -from App_Function_Libraries.Utils.Utils import cleanup_temp_files +from App_Function_Libraries.Utils.Utils import cleanup_temp_files, default_api_endpoint, global_api_endpoints, \ + format_api_name # Import metrics logging from App_Function_Libraries.Metrics.metrics_logger import log_counter, log_histogram from App_Function_Libraries.Metrics.logger_config import logger @@ -22,6 +24,18 @@ from App_Function_Libraries.Metrics.logger_config import logger def create_audio_processing_tab(): with gr.TabItem("Audio File Transcription + Summarization", visible=True): gr.Markdown("# Transcribe & Summarize Audio Files from URLs or Local Files!") + # Get and validate default value + try: + default_value = None + if default_api_endpoint: + if default_api_endpoint in global_api_endpoints: + default_value = format_api_name(default_api_endpoint) + else: + logging.warning(f"Default API endpoint '{default_api_endpoint}' not found in global_api_endpoints") + except Exception as e: + logging.error(f"Error setting 
default API endpoint: {str(e)}") + default_value = None + with gr.Row(): with gr.Column(): audio_url_input = gr.Textbox(label="Audio File URL(s)", placeholder="Enter the URL(s) of the audio file(s), one per line") @@ -46,54 +60,133 @@ def create_audio_processing_tab(): keep_timestamps_input = gr.Checkbox(label="Keep Timestamps", value=True) with gr.Row(): - custom_prompt_checkbox = gr.Checkbox(label="Use a Custom Prompt", - value=False, - visible=True) - preset_prompt_checkbox = gr.Checkbox(label="Use a pre-set Prompt", - value=False, - visible=True) + custom_prompt_checkbox = gr.Checkbox( + label="Use a Custom Prompt", + value=False, + visible=True + ) + preset_prompt_checkbox = gr.Checkbox( + label="Use a pre-set Prompt", + value=False, + visible=True + ) + + # Initialize state variables for pagination + current_page_state = gr.State(value=1) + total_pages_state = gr.State(value=1) + + with gr.Row(): + # Add pagination controls + preset_prompt = gr.Dropdown( + label="Select Preset Prompt", + choices=[], + visible=False + ) with gr.Row(): - preset_prompt = gr.Dropdown(label="Select Preset Prompt", - choices=load_preset_prompts(), - visible=False) + prev_page_button = gr.Button("Previous Page", visible=False) + page_display = gr.Markdown("Page 1 of X", visible=False) + next_page_button = gr.Button("Next Page", visible=False) + with gr.Row(): - custom_prompt_input = gr.Textbox(label="Custom Prompt", - placeholder="Enter custom prompt here", - lines=3, - visible=False) + custom_prompt_input = gr.Textbox( + label="Custom Prompt", + placeholder="Enter custom prompt here", + lines=3, + visible=False + ) with gr.Row(): - system_prompt_input = gr.Textbox(label="System Prompt", - value="""You are a bulleted notes specialist. [INST]```When creating comprehensive bulleted notes, you should follow these guidelines: Use multiple headings based on the referenced topics, not categories like quotes or terms. 
Headings should be surrounded by bold formatting and not be listed as bullet points themselves. Leave no space between headings and their corresponding list items underneath. Important terms within the content should be emphasized by setting them in bold font. Any text that ends with a colon should also be bolded. Before submitting your response, review the instructions, and make any corrections necessary to adhered to the specified format. Do not reference these instructions within the notes.``` \nBased on the content between backticks create comprehensive bulleted notes.[/INST] -**Bulleted Note Creation Guidelines** - -**Headings**: -- Based on referenced topics, not categories like quotes or terms -- Surrounded by **bold** formatting -- Not listed as bullet points -- No space between headings and list items underneath - -**Emphasis**: -- **Important terms** set in bold font -- **Text ending in a colon**: also bolded - -**Review**: -- Ensure adherence to specified format -- Do not reference these instructions in your response.[INST] {{ .Prompt }} [/INST] -""", - lines=3, - visible=False) + system_prompt_input = gr.Textbox( + label="System Prompt", + value="""You are a bulleted notes specialist. [INST]```When creating comprehensive bulleted notes, you should follow these guidelines: Use multiple headings based on the referenced topics, not categories like quotes or terms. Headings should be surrounded by bold formatting and not be listed as bullet points themselves. Leave no space between headings and their corresponding list items underneath. Important terms within the content should be emphasized by setting them in bold font. Any text that ends with a colon should also be bolded. Before submitting your response, review the instructions, and make any corrections necessary to adhere to the specified format. 
Do not reference these instructions within the notes.``` \nBased on the content between backticks create comprehensive bulleted notes.[/INST] + **Bulleted Note Creation Guidelines** + + **Headings**: + - Based on referenced topics, not categories like quotes or terms + - Surrounded by **bold** formatting + - Not listed as bullet points + - No space between headings and list items underneath + + **Emphasis**: + - **Important terms** set in bold font + - **Text ending in a colon**: also bolded + + **Review**: + - Ensure adherence to specified format + - Do not reference these instructions in your response.[INST] {{ .Prompt }} [/INST] + """, + lines=3, + visible=False + ) custom_prompt_checkbox.change( fn=lambda x: (gr.update(visible=x), gr.update(visible=x)), inputs=[custom_prompt_checkbox], outputs=[custom_prompt_input, system_prompt_input] ) + + # Handle preset prompt checkbox change + def on_preset_prompt_checkbox_change(is_checked): + if is_checked: + prompts, total_pages, current_page = list_prompts(page=1, per_page=10) + page_display_text = f"Page {current_page} of {total_pages}" + return ( + gr.update(visible=True, interactive=True, choices=prompts), # preset_prompt + gr.update(visible=True), # prev_page_button + gr.update(visible=True), # next_page_button + gr.update(value=page_display_text, visible=True), # page_display + current_page, # current_page_state + total_pages # total_pages_state + ) + else: + return ( + gr.update(visible=False, interactive=False), # preset_prompt + gr.update(visible=False), # prev_page_button + gr.update(visible=False), # next_page_button + gr.update(visible=False), # page_display + 1, # current_page_state + 1 # total_pages_state + ) + preset_prompt_checkbox.change( - fn=lambda x: gr.update(visible=x), + fn=on_preset_prompt_checkbox_change, inputs=[preset_prompt_checkbox], - outputs=[preset_prompt] + outputs=[preset_prompt, prev_page_button, next_page_button, page_display, current_page_state, total_pages_state] + ) + + # 
Pagination button functions + def on_prev_page_click(current_page, total_pages): + new_page = max(current_page - 1, 1) + prompts, total_pages, current_page = list_prompts(page=new_page, per_page=10) + page_display_text = f"Page {current_page} of {total_pages}" + return ( + gr.update(choices=prompts), + gr.update(value=page_display_text), + current_page + ) + + prev_page_button.click( + fn=on_prev_page_click, + inputs=[current_page_state, total_pages_state], + outputs=[preset_prompt, page_display, current_page_state] ) + def on_next_page_click(current_page, total_pages): + new_page = min(current_page + 1, total_pages) + prompts, total_pages, current_page = list_prompts(page=new_page, per_page=10) + page_display_text = f"Page {current_page} of {total_pages}" + return ( + gr.update(choices=prompts), + gr.update(value=page_display_text), + current_page + ) + + next_page_button.click( + fn=on_next_page_click, + inputs=[current_page_state, total_pages_state], + outputs=[preset_prompt, page_display, current_page_state] + ) + + # Update prompts when a preset is selected def update_prompts(preset_name): prompts = update_user_prompt(preset_name) return ( @@ -103,15 +196,14 @@ def create_audio_processing_tab(): preset_prompt.change( update_prompts, - inputs=preset_prompt, + inputs=[preset_prompt], outputs=[custom_prompt_input, system_prompt_input] ) - + # Refactored API selection dropdown api_name_input = gr.Dropdown( - choices=[None, "Local-LLM", "OpenAI", "Anthropic", "Cohere", "Groq", "DeepSeek", "Mistral", "OpenRouter", - "Llama.cpp", "Kobold", "Ooba", "Tabbyapi", "VLLM","ollama", "HuggingFace", "Custom-OpenAI-API"], - value=None, - label="API for Summarization (Optional)" + choices=["None"] + [format_api_name(api) for api in global_api_endpoints], + value=default_value, + label="API for Summarization/Analysis (Optional)" ) api_key_input = gr.Textbox(label="API Key (if required)", placeholder="Enter your API key here", type="password") custom_keywords_input = 
gr.Textbox(label="Custom Keywords", placeholder="Enter custom keywords, comma-separated") diff --git a/App_Function_Libraries/Gradio_UI/Backup_Functionality.py b/App_Function_Libraries/Gradio_UI/Backup_Functionality.py index c4bc198ec7ea0b811e7d60bf0756f305bc4d3951..2c31657cfff44919f0e53e96a1e5cda24f527147 100644 --- a/App_Function_Libraries/Gradio_UI/Backup_Functionality.py +++ b/App_Function_Libraries/Gradio_UI/Backup_Functionality.py @@ -14,7 +14,7 @@ from App_Function_Libraries.DB.DB_Manager import create_automated_backup, db_pat # # Functions: -def create_backup(): +def create_db_backup(): backup_file = create_automated_backup(db_path, backup_dir) return f"Backup created: {backup_file}" @@ -42,18 +42,7 @@ def create_backup_tab(): create_button = gr.Button("Create Backup") create_output = gr.Textbox(label="Result") with gr.Column(): - create_button.click(create_backup, inputs=[], outputs=create_output) - - -def create_view_backups_tab(): - with gr.TabItem("View Backups", visible=True): - gr.Markdown("# Browse available backups") - with gr.Row(): - with gr.Column(): - view_button = gr.Button("View Backups") - with gr.Column(): - backup_list = gr.Textbox(label="Available Backups") - view_button.click(list_backups, inputs=[], outputs=backup_list) + create_button.click(create_db_backup, inputs=[], outputs=create_output) def create_restore_backup_tab(): diff --git a/App_Function_Libraries/Gradio_UI/Backup_RAG_Notes_Character_Chat_tab.py b/App_Function_Libraries/Gradio_UI/Backup_RAG_Notes_Character_Chat_tab.py new file mode 100644 index 0000000000000000000000000000000000000000..73d1be040aa3eeaef8dc90b129717a30128e1258 --- /dev/null +++ b/App_Function_Libraries/Gradio_UI/Backup_RAG_Notes_Character_Chat_tab.py @@ -0,0 +1,195 @@ +# Backup_Functionality.py +# Functionality for managing database backups +# +# Imports: +import os +import shutil +import gradio as gr +from typing import Dict, List +# +# Local Imports: +from App_Function_Libraries.DB.DB_Manager import 
create_automated_backup +from App_Function_Libraries.DB.DB_Backups import create_backup, create_incremental_backup, restore_single_db_backup + + +# +# End of Imports +####################################################################################################################### +# +# Functions: + +def get_db_specific_backups(backup_dir: str, db_name: str) -> List[str]: + """Get list of backups specific to a database.""" + all_backups = [f for f in os.listdir(backup_dir) if f.endswith(('.db', '.sqlib'))] + db_specific_backups = [ + backup for backup in all_backups + if backup.startswith(f"{db_name}_") + ] + return sorted(db_specific_backups, reverse=True) # Most recent first + +def create_backup_tab(db_path: str, backup_dir: str, db_name: str): + """Create the backup creation tab for a database.""" + gr.Markdown("## Create Database Backup") + gr.Markdown(f"This will create a backup in the directory: `{backup_dir}`") + with gr.Row(): + with gr.Column(): + #automated_backup_btn = gr.Button("Create Simple Backup") + full_backup_btn = gr.Button("Create Full Backup") + incr_backup_btn = gr.Button("Create Incremental Backup") + with gr.Column(): + backup_output = gr.Textbox(label="Result") + + def create_db_backup(): + backup_file = create_automated_backup(db_path, backup_dir) + return f"Backup created: {backup_file}" + + # automated_backup_btn.click( + # fn=create_db_backup, + # inputs=[], + # outputs=[backup_output] + # ) + full_backup_btn.click( + fn=lambda: create_backup(db_path, backup_dir, db_name), + inputs=[], + outputs=[backup_output] + ) + incr_backup_btn.click( + fn=lambda: create_incremental_backup(db_path, backup_dir, db_name), + inputs=[], + outputs=[backup_output] + ) + +def create_view_backups_tab(backup_dir: str, db_name: str): + """Create the backup viewing tab for a database.""" + gr.Markdown("## Available Backups") + with gr.Row(): + with gr.Column(): + view_btn = gr.Button("Refresh Backup List") + with gr.Column(): + backup_list = 
gr.Textbox(label="Available Backups") + + def list_db_backups(): + """List backups specific to this database.""" + backups = get_db_specific_backups(backup_dir, db_name) + return "\n".join(backups) if backups else f"No backups found for {db_name} database" + + view_btn.click( + fn=list_db_backups, + inputs=[], + outputs=[backup_list] + ) + +def validate_backup_name(backup_name: str, db_name: str) -> bool: + """Validate that the backup name matches the database being restored.""" + # Check if backup name starts with the database name prefix and has valid extension + valid_prefixes = [ + f"{db_name}_backup_", # Full backup prefix + f"{db_name}_incremental_" # Incremental backup prefix + ] + has_valid_prefix = any(backup_name.startswith(prefix) for prefix in valid_prefixes) + has_valid_extension = backup_name.endswith(('.db', '.sqlib')) + return has_valid_prefix and has_valid_extension + +def create_restore_backup_tab(db_path: str, backup_dir: str, db_name: str): + """Create the backup restoration tab for a database.""" + gr.Markdown("## Restore Database") + gr.Markdown("⚠️ **Warning**: Restoring a backup will overwrite the current database.") + with gr.Row(): + with gr.Column(): + backup_input = gr.Textbox(label="Backup Filename") + restore_btn = gr.Button("Restore", variant="primary") + with gr.Column(): + restore_output = gr.Textbox(label="Result") + + def secure_restore(backup_name: str) -> str: + """Restore backup with validation checks.""" + if not backup_name: + return "Please enter a backup filename" + + # Validate backup name format + if not validate_backup_name(backup_name, db_name): + return f"Invalid backup file. 
Please select a backup file that starts with '{db_name}_backup_' or '{db_name}_incremental_'" + + # Check if backup exists + backup_path = os.path.join(backup_dir, backup_name) + if not os.path.exists(backup_path): + return f"Backup file not found: {backup_name}" + + # Proceed with restore + return restore_single_db_backup(db_path, backup_dir, db_name, backup_name) + + restore_btn.click( + fn=secure_restore, + inputs=[backup_input], + outputs=[restore_output] + ) + +def create_media_db_tabs(db_config: Dict[str, str]): + """Create all tabs for the Media database.""" + create_backup_tab( + db_path=db_config['db_path'], + backup_dir=db_config['backup_dir'], + db_name='media' + ) + create_view_backups_tab( + backup_dir=db_config['backup_dir'], + db_name='media' + ) + create_restore_backup_tab( + db_path=db_config['db_path'], + backup_dir=db_config['backup_dir'], + db_name='media' + ) + +def create_rag_chat_tabs(db_config: Dict[str, str]): + """Create all tabs for the RAG Chat database.""" + create_backup_tab( + db_path=db_config['db_path'], + backup_dir=db_config['backup_dir'], + db_name='rag_qa' # Updated to match DB_Manager.py + ) + create_view_backups_tab( + backup_dir=db_config['backup_dir'], + db_name='rag_qa' # Updated to match DB_Manager.py + ) + create_restore_backup_tab( + db_path=db_config['db_path'], + backup_dir=db_config['backup_dir'], + db_name='rag_qa' # Updated to match DB_Manager.py + ) + +def create_character_chat_tabs(db_config: Dict[str, str]): + """Create all tabs for the Character Chat database.""" + create_backup_tab( + db_path=db_config['db_path'], + backup_dir=db_config['backup_dir'], + db_name='chatDB' # Updated to match DB_Manager.py + ) + create_view_backups_tab( + backup_dir=db_config['backup_dir'], + db_name='chatDB' # Updated to match DB_Manager.py + ) + create_restore_backup_tab( + db_path=db_config['db_path'], + backup_dir=db_config['backup_dir'], + db_name='chatDB' + ) + +def create_database_management_interface( + media_db_config: 
Dict[str, str], + rag_db_config: Dict[str, str], + char_db_config: Dict[str, str] +): + """Create the main database management interface with tabs for each database.""" + with gr.TabItem("Media Database", id="media_db_group", visible=True): + create_media_db_tabs(media_db_config) + + with gr.TabItem("RAG Chat Database", id="rag_chat_group", visible=True): + create_rag_chat_tabs(rag_db_config) + + with gr.TabItem("Character Chat Database", id="character_chat_group", visible=True): + create_character_chat_tabs(char_db_config) + +# +# End of Functions +####################################################################################################################### diff --git a/App_Function_Libraries/Gradio_UI/Book_Ingestion_tab.py b/App_Function_Libraries/Gradio_UI/Book_Ingestion_tab.py index cc455dfa67109a7f9ab95ab17fe1d67ae9142b67..7888b53532e0215c01b1847e464d3789254332bd 100644 --- a/App_Function_Libraries/Gradio_UI/Book_Ingestion_tab.py +++ b/App_Function_Libraries/Gradio_UI/Book_Ingestion_tab.py @@ -8,69 +8,113 @@ # #################### # Imports +import logging # # External Imports import gradio as gr # # Local Imports -from App_Function_Libraries.Books.Book_Ingestion_Lib import process_zip_file, import_epub, import_file_handler +from App_Function_Libraries.Books.Book_Ingestion_Lib import import_file_handler +from App_Function_Libraries.Utils.Utils import default_api_endpoint, global_api_endpoints, format_api_name # ######################################################################################################################## # # Functions: - - def create_import_book_tab(): + try: + default_value = None + if default_api_endpoint: + if default_api_endpoint in global_api_endpoints: + default_value = format_api_name(default_api_endpoint) + else: + logging.warning(f"Default API endpoint '{default_api_endpoint}' not found in global_api_endpoints") + except Exception as e: + logging.error(f"Error setting default API endpoint: {str(e)}") + default_value 
= None + with gr.TabItem("Ebook(epub) Files", visible=True): with gr.Row(): with gr.Column(): gr.Markdown("# Import .epub files") - gr.Markdown("Upload a single .epub file or a .zip file containing multiple .epub files") + gr.Markdown("Upload multiple .epub files or a .zip file containing multiple .epub files") gr.Markdown( "🔗 **How to remove DRM from your ebooks:** [Reddit Guide](https://www.reddit.com/r/Calibre/comments/1ck4w8e/2024_guide_on_removing_drm_from_kobo_kindle_ebooks/)") - import_file = gr.File(label="Upload file for import", file_types=[".epub", ".zip"]) - title_input = gr.Textbox(label="Title", placeholder="Enter the title of the content (for single files)") - author_input = gr.Textbox(label="Author", placeholder="Enter the author's name (for single files)") - keywords_input = gr.Textbox(label="Keywords (like genre or publish year)", - placeholder="Enter keywords, comma-separated") - system_prompt_input = gr.Textbox(label="System Prompt", lines=3, - value="""" - You are a bulleted notes specialist. [INST]```When creating comprehensive bulleted notes, you should follow these guidelines: Use multiple headings based on the referenced topics, not categories like quotes or terms. Headings should be surrounded by bold formatting and not be listed as bullet points themselves. Leave no space between headings and their corresponding list items underneath. Important terms within the content should be emphasized by setting them in bold font. Any text that ends with a colon should also be bolded. Before submitting your response, review the instructions, and make any corrections necessary to adhered to the specified format. 
Do not reference these instructions within the notes.``` \nBased on the content between backticks create comprehensive bulleted notes.[/INST] - **Bulleted Note Creation Guidelines** - - **Headings**: - - Based on referenced topics, not categories like quotes or terms - - Surrounded by **bold** formatting - - Not listed as bullet points - - No space between headings and list items underneath - **Emphasis**: - - **Important terms** set in bold font - - **Text ending in a colon**: also bolded + # Updated to support multiple files + import_files = gr.File( + label="Upload files for import", + file_count="multiple", + file_types=[".epub", ".zip", ".html", ".htm", ".xml", ".opml"] + ) - **Review**: - - Ensure adherence to specified format - - Do not reference these instructions in your response.[INST] {{ .Prompt }} [/INST] - """, ) - custom_prompt_input = gr.Textbox(label="Custom User Prompt", - placeholder="Enter a custom user prompt for summarization (optional)") + # Optional fields for overriding auto-extracted metadata + author_input = gr.Textbox( + label="Author Override (optional)", + placeholder="Enter author name to override auto-extracted metadata" + ) + keywords_input = gr.Textbox( + label="Keywords (like genre or publish year)", + placeholder="Enter keywords, comma-separated - will be applied to all uploaded books" + ) + system_prompt_input = gr.Textbox( + label="System Prompt", + lines=3, + value="""" + You are a bulleted notes specialist. [INST]```When creating comprehensive bulleted notes, you should follow these guidelines: Use multiple headings based on the referenced topics, not categories like quotes or terms. Headings should be surrounded by bold formatting and not be listed as bullet points themselves. Leave no space between headings and their corresponding list items underneath. Important terms within the content should be emphasized by setting them in bold font. Any text that ends with a colon should also be bolded. 
Before submitting your response, review the instructions, and make any corrections necessary to adhered to the specified format. Do not reference these instructions within the notes.``` \nBased on the content between backticks create comprehensive bulleted notes.[/INST] + **Bulleted Note Creation Guidelines** + + **Headings**: + - Based on referenced topics, not categories like quotes or terms + - Surrounded by **bold** formatting + - Not listed as bullet points + - No space between headings and list items underneath + + **Emphasis**: + - **Important terms** set in bold font + - **Text ending in a colon**: also bolded + + **Review**: + - Ensure adherence to specified format + - Do not reference these instructions in your response.[INST] + """ + ) + custom_prompt_input = gr.Textbox( + label="Custom User Prompt", + placeholder="Enter a custom user prompt for summarization (optional)" + ) auto_summarize_checkbox = gr.Checkbox(label="Auto-summarize", value=False) + + # API configuration api_name_input = gr.Dropdown( - choices=[None, "Local-LLM", "OpenAI", "Anthropic", "Cohere", "Groq", "DeepSeek", "Mistral", - "OpenRouter", "Llama.cpp", "Kobold", "Ooba", "Tabbyapi", "VLLM", "ollama", "HuggingFace"], - label="API for Auto-summarization" + choices=["None"] + [format_api_name(api) for api in global_api_endpoints], + value=default_value, + label="API for Summarization/Analysis (Optional)" ) api_key_input = gr.Textbox(label="API Key", type="password") # Chunking options - max_chunk_size = gr.Slider(minimum=100, maximum=2000, value=500, step=50, label="Max Chunk Size") - chunk_overlap = gr.Slider(minimum=0, maximum=500, value=200, step=10, label="Chunk Overlap") - custom_chapter_pattern = gr.Textbox(label="Custom Chapter Pattern (optional)", - placeholder="Enter a custom regex pattern for chapter detection") + max_chunk_size = gr.Slider( + minimum=100, + maximum=2000, + value=500, + step=50, + label="Max Chunk Size" + ) + chunk_overlap = gr.Slider( + minimum=0, + 
maximum=500, + value=200, + step=10, + label="Chunk Overlap" + ) + custom_chapter_pattern = gr.Textbox( + label="Custom Chapter Pattern (optional)", + placeholder="Enter a custom regex pattern for chapter detection" + ) + import_button = gr.Button("Import eBooks") - import_button = gr.Button("Import eBook(s)") with gr.Column(): with gr.Row(): import_output = gr.Textbox(label="Import Status", lines=10, interactive=False) @@ -78,10 +122,10 @@ def create_import_book_tab(): import_button.click( fn=import_file_handler, inputs=[ - import_file, - title_input, + import_files, # Now handles multiple files author_input, keywords_input, + system_prompt_input, custom_prompt_input, auto_summarize_checkbox, api_name_input, @@ -93,8 +137,8 @@ def create_import_book_tab(): outputs=import_output ) - return import_file, title_input, author_input, keywords_input, system_prompt_input, custom_prompt_input, auto_summarize_checkbox, api_name_input, api_key_input, import_button, import_output + return import_files, author_input, keywords_input, system_prompt_input, custom_prompt_input, auto_summarize_checkbox, api_name_input, api_key_input, import_button, import_output # # End of File -######################################################################################################################## \ No newline at end of file +######################################################################################################################## diff --git a/App_Function_Libraries/Gradio_UI/Character_Chat_tab.py b/App_Function_Libraries/Gradio_UI/Character_Chat_tab.py index 86c173ea2961306aea6edb60c29d62cd2b4decf0..6a49ae16c6b20fa8b38d253ef02a4ad550747d27 100644 --- a/App_Function_Libraries/Gradio_UI/Character_Chat_tab.py +++ b/App_Function_Libraries/Gradio_UI/Character_Chat_tab.py @@ -2,10 +2,10 @@ # Description: Library for character card import functions # # Imports +from datetime import datetime import re import tempfile import uuid -from datetime import datetime import json 
import logging import io @@ -21,7 +21,7 @@ import gradio as gr from App_Function_Libraries.Character_Chat.Character_Chat_Lib import validate_character_book, validate_v2_card, \ replace_placeholders, replace_user_placeholder, extract_json_from_image, parse_character_book, \ load_chat_and_character, load_chat_history, load_character_and_image, extract_character_id, load_character_wrapper -from App_Function_Libraries.Chat import chat +from App_Function_Libraries.Chat.Chat_Functions import chat, approximate_token_count from App_Function_Libraries.DB.Character_Chat_DB import ( add_character_card, get_character_cards, @@ -32,9 +32,12 @@ from App_Function_Libraries.DB.Character_Chat_DB import ( update_character_chat, delete_character_chat, delete_character_card, - update_character_card, search_character_chats, + update_character_card, search_character_chats, save_chat_history_to_character_db, ) -from App_Function_Libraries.Utils.Utils import sanitize_user_input +from App_Function_Libraries.Utils.Utils import sanitize_user_input, format_api_name, global_api_endpoints, \ + default_api_endpoint, load_comprehensive_config + + # ############################################################################################################ # @@ -252,8 +255,37 @@ def export_all_characters(): # Gradio tabs def create_character_card_interaction_tab(): + try: + default_value = None + if default_api_endpoint: + if default_api_endpoint in global_api_endpoints: + default_value = format_api_name(default_api_endpoint) + else: + logging.warning(f"Default API endpoint '{default_api_endpoint}' not found in global_api_endpoints") + except Exception as e: + logging.error(f"Error setting default API endpoint: {str(e)}") + default_value = None with gr.TabItem("Chat with a Character Card", visible=True): gr.Markdown("# Chat with a Character Card") + with gr.Row(): + with gr.Column(scale=1): + # Checkbox to Decide Whether to Save Chats by Default + config = load_comprehensive_config() + 
auto_save_value = config.get('auto-save', 'save_character_chats', fallback='False') + auto_save_checkbox = gr.Checkbox(label="Save chats automatically", value=auto_save_value) + chat_media_name = gr.Textbox(label="Custom Chat Name (optional)", visible=True) + save_chat_history_to_db = gr.Button("Save Chat History to Database") + save_status = gr.Textbox(label="Status", interactive=False) + with gr.Column(scale=2): + gr.Markdown("## Search and Load Existing Chats") + chat_search_query = gr.Textbox( + label="Search Chats", + placeholder="Enter chat name or keywords to search" + ) + chat_search_button = gr.Button("Search Chats") + chat_search_dropdown = gr.Dropdown(label="Search Results", choices=[], visible=False) + load_chat_button = gr.Button("Load Selected Chat", visible=False) + with gr.Row(): with gr.Column(scale=1): character_image = gr.Image(label="Character Image", type="pil") @@ -265,13 +297,10 @@ def create_character_card_interaction_tab(): load_characters_button = gr.Button("Load Existing Characters") character_dropdown = gr.Dropdown(label="Select Character", choices=[]) user_name_input = gr.Textbox(label="Your Name", placeholder="Enter your name here") + # Refactored API selection dropdown api_name_input = gr.Dropdown( - choices=[ - "Local-LLM", "OpenAI", "Anthropic", "Cohere", "Groq", "DeepSeek", "Mistral", - "OpenRouter", "Llama.cpp", "Kobold", "Ooba", "Tabbyapi", "VLLM", "ollama", "HuggingFace", - "Custom-OpenAI-API" - ], - value="HuggingFace", + choices=["None"] + [format_api_name(api) for api in global_api_endpoints], + value=default_value, label="API for Interaction (Mandatory)" ) api_key_input = gr.Textbox( @@ -281,24 +310,8 @@ def create_character_card_interaction_tab(): temperature_slider = gr.Slider( minimum=0.0, maximum=2.0, value=0.7, step=0.05, label="Temperature" ) - import_chat_button = gr.Button("Import Chat History") chat_file_upload = gr.File(label="Upload Chat History JSON", visible=True) - - # Chat History Import and Search - 
gr.Markdown("## Search and Load Existing Chats") - chat_search_query = gr.Textbox( - label="Search Chats", - placeholder="Enter chat name or keywords to search" - ) - chat_search_button = gr.Button("Search Chats") - chat_search_dropdown = gr.Dropdown(label="Search Results", choices=[], visible=False) - load_chat_button = gr.Button("Load Selected Chat", visible=False) - - # Checkbox to Decide Whether to Save Chats by Default - auto_save_checkbox = gr.Checkbox(label="Save chats automatically", value=True) - chat_media_name = gr.Textbox(label="Custom Chat Name (optional)", visible=True) - save_chat_history_to_db = gr.Button("Save Chat History to Database") - save_status = gr.Textbox(label="Save Status", interactive=False) + import_chat_button = gr.Button("Import Chat History") with gr.Column(scale=2): chat_history = gr.Chatbot(label="Conversation", height=800) @@ -307,6 +320,7 @@ def create_character_card_interaction_tab(): answer_for_me_button = gr.Button("Answer for Me") continue_talking_button = gr.Button("Continue Talking") regenerate_button = gr.Button("Regenerate Last Message") + token_count_display = gr.Number(label="Approximate Token Count", value=0, interactive=False) clear_chat_button = gr.Button("Clear Chat") save_snapshot_button = gr.Button("Save Chat Snapshot") update_chat_dropdown = gr.Dropdown(label="Select Chat to Update", choices=[], visible=False) @@ -491,23 +505,114 @@ def create_character_card_interaction_tab(): return history, save_status + def validate_chat_history(chat_history: List[Tuple[Optional[str], str]]) -> bool: + """ + Validate the chat history format and content. 
+ + Args: + chat_history: List of message tuples (user_message, bot_message) + + Returns: + bool: True if valid, False if invalid + """ + if not isinstance(chat_history, list): + return False + + for entry in chat_history: + if not isinstance(entry, tuple) or len(entry) != 2: + return False + # First element can be None (for system messages) or str + if not (entry[0] is None or isinstance(entry[0], str)): + return False + # Second element (bot response) must be str and not empty + if not isinstance(entry[1], str) or not entry[1].strip(): + return False + + return True + + def sanitize_conversation_name(name: str) -> str: + """ + Sanitize the conversation name. + + Args: + name: Raw conversation name + + Returns: + str: Sanitized conversation name + """ + # Remove any non-alphanumeric characters except spaces and basic punctuation + sanitized = re.sub(r'[^a-zA-Z0-9\s\-_.]', '', name) + # Limit length + sanitized = sanitized[:100] + # Ensure it's not empty + if not sanitized.strip(): + sanitized = f"Chat_{datetime.now().strftime('%Y%m%d_%H%M%S')}" + return sanitized + def save_chat_history_to_db_wrapper( - chat_history, conversation_id, media_content, - chat_media_name, char_data, auto_save - ): - if not char_data or not chat_history: - return "No character or chat history available.", "" + chat_history: List[Tuple[Optional[str], str]], + conversation_id: str, + media_content: Dict, + chat_media_name: str, + char_data: Dict, + auto_save: bool + ) -> Tuple[str, str]: + """ + Save chat history to the database with validation. 
- character_id = char_data.get('id') - if not character_id: - return "Character ID not found.", "" + Args: + chat_history: List of message tuples + conversation_id: Current conversation ID + media_content: Media content metadata + chat_media_name: Custom name for the chat + char_data: Character data dictionary + auto_save: Auto-save flag - conversation_name = chat_media_name or f"Chat {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}" - chat_id = add_character_chat(character_id, conversation_name, chat_history) - if chat_id: - return f"Chat saved successfully with ID {chat_id}.", "" - else: - return "Failed to save chat.", "" + Returns: + Tuple[str, str]: (status message, detail message) + """ + try: + # Basic input validation + if not chat_history: + return "No chat history to save.", "" + + if not validate_chat_history(chat_history): + return "Invalid chat history format.", "Please ensure the chat history is valid." + + if not char_data: + return "No character selected.", "Please select a character first." + + character_id = char_data.get('id') + if not character_id: + return "Invalid character data: No character ID found.", "" + + # Sanitize and prepare conversation name + conversation_name = sanitize_conversation_name( + chat_media_name if chat_media_name.strip() + else f"Chat with {char_data.get('name', 'Unknown')} - {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}" + ) + + # Save to the database using your existing function + chat_id = save_chat_history_to_character_db( + character_id=character_id, + conversation_name=conversation_name, + chat_history=chat_history + ) + + if chat_id: + success_message = ( + f"Chat saved successfully!\n" + f"ID: {chat_id}\n" + f"Name: {conversation_name}\n" + f"Messages: {len(chat_history)}" + ) + return success_message, "" + else: + return "Failed to save chat to database.", "Database operation failed." 
+ + except Exception as e: + logging.error(f"Error saving chat history: {str(e)}", exc_info=True) + return f"Error saving chat: {str(e)}", "Please check the logs for more details." def update_character_info(name): return load_character_and_image(name, user_name.value) @@ -871,6 +976,10 @@ def create_character_card_interaction_tab(): auto_save_checkbox ], outputs=[chat_history, save_status] + ).then( + lambda history: approximate_token_count(history), + inputs=[chat_history], + outputs=[token_count_display] ) continue_talking_button.click( @@ -885,6 +994,10 @@ def create_character_card_interaction_tab(): auto_save_checkbox ], outputs=[chat_history, save_status] + ).then( + lambda history: approximate_token_count(history), + inputs=[chat_history], + outputs=[token_count_display] ) import_card_button.click( @@ -903,6 +1016,10 @@ def create_character_card_interaction_tab(): fn=clear_chat_history, inputs=[character_data, user_name_input], outputs=[chat_history, character_data] + ).then( + lambda history: approximate_token_count(history), + inputs=[chat_history], + outputs=[token_count_display] ) character_dropdown.change( @@ -928,7 +1045,13 @@ def create_character_card_interaction_tab(): auto_save_checkbox ], outputs=[chat_history, save_status] - ).then(lambda: "", outputs=user_input) + ).then( + lambda: "", outputs=user_input + ).then( + lambda history: approximate_token_count(history), + inputs=[chat_history], + outputs=[token_count_display] + ) regenerate_button.click( fn=regenerate_last_message, @@ -942,6 +1065,10 @@ def create_character_card_interaction_tab(): auto_save_checkbox ], outputs=[chat_history, save_status] + ).then( + lambda history: approximate_token_count(history), + inputs=[chat_history], + outputs=[token_count_display] ) import_chat_button.click( @@ -951,8 +1078,12 @@ def create_character_card_interaction_tab(): chat_file_upload.change( fn=import_chat_history, - inputs=[chat_file_upload, chat_history, character_data], + inputs=[chat_file_upload, 
chat_history, character_data, user_name_input], outputs=[chat_history, character_data, save_status] + ).then( + lambda history: approximate_token_count(history), + inputs=[chat_history], + outputs=[token_count_display] ) save_chat_history_to_db.click( @@ -1009,6 +1140,10 @@ def create_character_card_interaction_tab(): fn=load_selected_chat_from_search, inputs=[chat_search_dropdown, user_name_input], outputs=[character_data, chat_history, character_image, save_status] + ).then( + lambda history: approximate_token_count(history), + inputs=[chat_history], + outputs=[token_count_display] ) # Show Load Chat Button when a chat is selected @@ -1023,8 +1158,8 @@ def create_character_card_interaction_tab(): def create_character_chat_mgmt_tab(): - with gr.TabItem("Character and Chat Management", visible=True): - gr.Markdown("# Character and Chat Management") + with gr.TabItem("Character Chat Management", visible=True): + gr.Markdown("# Character Chat Management") with gr.Row(): # Left Column: Character Import and Chat Management @@ -1057,13 +1192,17 @@ def create_character_chat_mgmt_tab(): gr.Markdown("## Chat Management") select_chat = gr.Dropdown(label="Select Chat", choices=[], visible=False, interactive=True) load_chat_button = gr.Button("Load Selected Chat", visible=False) - conversation_list = gr.Dropdown(label="Select Conversation or Character", choices=[]) + conversation_list = gr.Dropdown(label="Select Conversation", choices=[]) conversation_mapping = gr.State({}) with gr.Tabs(): with gr.TabItem("Edit", visible=True): chat_content = gr.TextArea(label="Chat/Character Content (JSON)", lines=20, max_lines=50) save_button = gr.Button("Save Changes") + export_chat_button = gr.Button("Export Current Conversation", variant="secondary") + export_all_chats_button = gr.Button("Export All Character Conversations", variant="secondary") + export_file = gr.File(label="Downloaded File", visible=False) + export_status = gr.Markdown("") delete_button = gr.Button("Delete 
Conversation/Character", variant="stop") with gr.TabItem("Preview", visible=True): @@ -1306,6 +1445,90 @@ def create_character_chat_mgmt_tab(): return "Import results:\n" + "\n".join(results) + def export_current_conversation(selected_chat): + if not selected_chat: + return "Please select a conversation to export.", None + + try: + chat_id = int(selected_chat.split('(ID: ')[1].rstrip(')')) + chat = get_character_chat_by_id(chat_id) + + if not chat: + return "Selected chat not found.", None + + # Ensure chat_history is properly parsed + chat_history = chat['chat_history'] + if isinstance(chat_history, str): + chat_history = json.loads(chat_history) + + export_data = { + "conversation_id": chat['id'], + "conversation_name": chat['conversation_name'], + "character_id": chat['character_id'], + "chat_history": chat_history, + "exported_at": datetime.now().isoformat() + } + + # Convert to JSON string + json_str = json.dumps(export_data, indent=2, ensure_ascii=False) + + # Create file name + file_name = f"conversation_{chat['id']}_{chat['conversation_name']}.json" + + # Return file for download + return "Conversation exported successfully!", (file_name, json_str, "application/json") + + except Exception as e: + logging.error(f"Error exporting conversation: {e}") + return f"Error exporting conversation: {str(e)}", None + + def export_all_character_conversations(character_selection): + if not character_selection: + return "Please select a character first.", None + + try: + character_id = int(character_selection.split('(ID: ')[1].rstrip(')')) + character = get_character_card_by_id(character_id) + chats = get_character_chats(character_id=character_id) + + if not chats: + return "No conversations found for this character.", None + + # Process chat histories + conversations = [] + for chat in chats: + chat_history = chat['chat_history'] + if isinstance(chat_history, str): + chat_history = json.loads(chat_history) + + conversations.append({ + "conversation_id": chat['id'], + 
"conversation_name": chat['conversation_name'], + "chat_history": chat_history + }) + + export_data = { + "character": { + "id": character['id'], + "name": character['name'] + }, + "conversations": conversations, + "exported_at": datetime.now().isoformat() + } + + # Convert to JSON string + json_str = json.dumps(export_data, indent=2, ensure_ascii=False) + + # Create file name + file_name = f"all_conversations_{character['name']}_{character['id']}.json" + + # Return file for download + return "All conversations exported successfully!", (file_name, json_str, "application/json") + + except Exception as e: + logging.error(f"Error exporting all conversations: {e}") + return f"Error exporting conversations: {str(e)}", None + # Register new callback for character import import_characters_button.click( fn=import_multiple_characters, @@ -1368,6 +1591,18 @@ def create_character_chat_mgmt_tab(): outputs=select_character ) + export_chat_button.click( + fn=export_current_conversation, + inputs=[select_chat], + outputs=[export_status, export_file] + ) + + export_all_chats_button.click( + fn=export_all_character_conversations, + inputs=[select_character], + outputs=[export_status, export_file] + ) + return ( character_files, import_characters_button, import_status, search_query, search_button, search_results, search_status, diff --git a/App_Function_Libraries/Gradio_UI/Character_interaction_tab.py b/App_Function_Libraries/Gradio_UI/Character_interaction_tab.py index 5d1738052b94369997ea157b13ba718f34b01ed8..d4f0598cdb03eb01104f2c212012ca5eb08f3616 100644 --- a/App_Function_Libraries/Gradio_UI/Character_interaction_tab.py +++ b/App_Function_Libraries/Gradio_UI/Character_interaction_tab.py @@ -17,9 +17,12 @@ import gradio as gr from PIL import Image # # Local Imports -from App_Function_Libraries.Chat import chat, load_characters, save_chat_history_to_db_wrapper +from App_Function_Libraries.Chat.Chat_Functions import chat, load_characters, save_chat_history_to_db_wrapper from 
App_Function_Libraries.Gradio_UI.Chat_ui import chat_wrapper from App_Function_Libraries.Gradio_UI.Writing_tab import generate_writing_feedback +from App_Function_Libraries.Utils.Utils import default_api_endpoint, format_api_name, global_api_endpoints + + # ######################################################################################################################## # @@ -253,6 +256,16 @@ def character_interaction(character1: str, character2: str, api_endpoint: str, a def create_multiple_character_chat_tab(): + try: + default_value = None + if default_api_endpoint: + if default_api_endpoint in global_api_endpoints: + default_value = format_api_name(default_api_endpoint) + else: + logging.warning(f"Default API endpoint '{default_api_endpoint}' not found in global_api_endpoints") + except Exception as e: + logging.error(f"Error setting default API endpoint: {str(e)}") + default_value = None with gr.TabItem("Multi-Character Chat", visible=True): characters, conversation, current_character, other_character = character_interaction_setup() @@ -264,13 +277,12 @@ def create_multiple_character_chat_tab(): character_selectors = [gr.Dropdown(label=f"Character {i + 1}", choices=list(characters.keys())) for i in range(4)] - api_endpoint = gr.Dropdown(label="API Endpoint", - choices=["Local-LLM", "OpenAI", "Anthropic", "Cohere", "Groq", "DeepSeek", - "Mistral", - "OpenRouter", "Llama.cpp", "Kobold", "Ooba", "Tabbyapi", "VLLM", - "ollama", "HuggingFace", - "Custom-OpenAI-API"], - value="HuggingFace") + # Refactored API selection dropdown + api_endpoint = gr.Dropdown( + choices=["None"] + [format_api_name(api) for api in global_api_endpoints], + value=default_value, + label="API for Interaction (Optional)" + ) api_key = gr.Textbox(label="API Key (if required)", type="password") temperature = gr.Slider(label="Temperature", minimum=0.1, maximum=1.0, step=0.1, value=0.7) scenario = gr.Textbox(label="Scenario (optional)", lines=3) @@ -393,17 +405,26 @@ def 
create_multiple_character_chat_tab(): # From `Fuzzlewumper` on Reddit. def create_narrator_controlled_conversation_tab(): + try: + default_value = None + if default_api_endpoint: + if default_api_endpoint in global_api_endpoints: + default_value = format_api_name(default_api_endpoint) + else: + logging.warning(f"Default API endpoint '{default_api_endpoint}' not found in global_api_endpoints") + except Exception as e: + logging.error(f"Error setting default API endpoint: {str(e)}") + default_value = None with gr.TabItem("Narrator-Controlled Conversation", visible=True): gr.Markdown("# Narrator-Controlled Conversation") with gr.Row(): with gr.Column(scale=1): + # Refactored API selection dropdown api_endpoint = gr.Dropdown( - label="API Endpoint", - choices=["Local-LLM", "OpenAI", "Anthropic", "Cohere", "Groq", "DeepSeek", "Mistral", - "OpenRouter", "Llama.cpp", "Kobold", "Ooba", "Tabbyapi", "VLLM", "ollama", "HuggingFace", - "Custom-OpenAI-API"], - value="HuggingFace" + choices=["None"] + [format_api_name(api) for api in global_api_endpoints], + value=default_value, + label="API for Chat Interaction (Optional)" ) api_key = gr.Textbox(label="API Key (if required)", type="password") temperature = gr.Slider(label="Temperature", minimum=0.1, maximum=1.0, step=0.1, value=0.7) diff --git a/App_Function_Libraries/Gradio_UI/Chat_ui.py b/App_Function_Libraries/Gradio_UI/Chat_ui.py index a8b55e68ac4f20028a858b4f261263fc3b46ce5d..0ec6ebf65ad5fc59a738e251199c0ea618687e25 100644 --- a/App_Function_Libraries/Gradio_UI/Chat_ui.py +++ b/App_Function_Libraries/Gradio_UI/Chat_ui.py @@ -2,23 +2,25 @@ # Description: Chat interface functions for Gradio # # Imports -import html -import json import logging import os import sqlite3 +import time from datetime import datetime # # External Imports import gradio as gr # # Local Imports -from App_Function_Libraries.Chat import chat, save_chat_history, update_chat_content, save_chat_history_to_db_wrapper -from 
App_Function_Libraries.DB.DB_Manager import add_chat_message, search_chat_conversations, create_chat_conversation, \ - get_chat_messages, update_chat_message, delete_chat_message, load_preset_prompts, db +from App_Function_Libraries.Chat.Chat_Functions import approximate_token_count, chat, save_chat_history, \ + update_chat_content, save_chat_history_to_db_wrapper +from App_Function_Libraries.DB.DB_Manager import db, load_chat_history, start_new_conversation, \ + save_message, search_conversations_by_keywords, \ + get_all_conversations, delete_messages_in_conversation, search_media_db, list_prompts +from App_Function_Libraries.DB.RAG_QA_Chat_DB import get_db_connection from App_Function_Libraries.Gradio_UI.Gradio_Shared import update_dropdown, update_user_prompt - - +from App_Function_Libraries.Metrics.metrics_logger import log_counter, log_histogram +from App_Function_Libraries.Utils.Utils import default_api_endpoint, format_api_name, global_api_endpoints # # ######################################################################################################################## @@ -91,10 +93,9 @@ def chat_wrapper(message, history, media_content, selected_parts, api_endpoint, # Create a new conversation media_id = media_content.get('id', None) conversation_name = f"Chat about {media_content.get('title', 'Unknown Media')} - {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}" - conversation_id = create_chat_conversation(media_id, conversation_name) - + conversation_id = start_new_conversation(title=conversation_name, media_id=media_id) # Add user message to the database - user_message_id = add_chat_message(conversation_id, "user", message) + user_message_id = save_message(conversation_id, role="user", content=message) # Include the selected parts and custom_prompt only for the first message if not history and selected_parts: @@ -113,7 +114,7 @@ def chat_wrapper(message, history, media_content, selected_parts, api_endpoint, if save_conversation: # Add assistant message to 
the database - add_chat_message(conversation_id, "assistant", bot_message) + save_message(conversation_id, role="assistant", content=bot_message) # Update history new_history = history + [(message, bot_message)] @@ -123,51 +124,57 @@ def chat_wrapper(message, history, media_content, selected_parts, api_endpoint, logging.error(f"Error in chat wrapper: {str(e)}") return "An error occurred.", history, conversation_id + def search_conversations(query): + """Convert existing chat search to use RAG chat functions""" try: - conversations = search_chat_conversations(query) - if not conversations: - print(f"Debug - Search Conversations - No results found for query: {query}") + # Use the RAG search function - search by title if given a query + if query and query.strip(): + results, _, _ = search_conversations_by_keywords( + title_query=query.strip() + ) + else: + # Get all conversations if no query + results, _, _ = get_all_conversations() + + if not results: return gr.update(choices=[]) + # Format choices to match existing UI format conversation_options = [ - (f"{c['conversation_name']} (Media: {c['media_title']}, ID: {c['id']})", c['id']) - for c in conversations + (f"{conv['title']} (ID: {conv['conversation_id'][:8]})", conv['conversation_id']) + for conv in results ] - print(f"Debug - Search Conversations - Options: {conversation_options}") + return gr.update(choices=conversation_options) except Exception as e: - print(f"Debug - Search Conversations - Error: {str(e)}") + logging.error(f"Error searching conversations: {str(e)}") return gr.update(choices=[]) def load_conversation(conversation_id): + """Convert existing load to use RAG chat functions""" if not conversation_id: return [], None - messages = get_chat_messages(conversation_id) - history = [ - (msg['message'], None) if msg['sender'] == 'user' else (None, msg['message']) - for msg in messages - ] - return history, conversation_id - - -def update_message_in_chat(message_id, new_text, history): - 
update_chat_message(message_id, new_text) - updated_history = [(msg1, msg2) if msg1[1] != message_id and msg2[1] != message_id - else ((new_text, msg1[1]) if msg1[1] == message_id else (new_text, msg2[1])) - for msg1, msg2 in history] - return updated_history + try: + # Use RAG load function + messages, _, _ = load_chat_history(conversation_id) + # Convert to chatbot history format + history = [ + (content, None) if role == 'user' else (None, content) + for role, content in messages + ] -def delete_message_from_chat(message_id, history): - delete_chat_message(message_id) - updated_history = [(msg1, msg2) for msg1, msg2 in history if msg1[1] != message_id and msg2[1] != message_id] - return updated_history + return history, conversation_id + except Exception as e: + logging.error(f"Error loading conversation: {str(e)}") + return [], None -def regenerate_last_message(history, media_content, selected_parts, api_endpoint, api_key, custom_prompt, temperature, system_prompt): +def regenerate_last_message(history, media_content, selected_parts, api_endpoint, api_key, custom_prompt, temperature, + system_prompt): if not history: return history, "No messages to regenerate." @@ -200,7 +207,56 @@ def regenerate_last_message(history, media_content, selected_parts, api_endpoint return new_history, "Last message regenerated successfully." 
+ +def update_dropdown_multiple(query, search_type, keywords=""): + """Updated function to handle multiple search results using search_media_db""" + try: + # Define search fields based on search type + search_fields = [] + if search_type.lower() == "keyword": + # When searching by keyword, we'll search across multiple fields + search_fields = ["title", "content", "author"] + else: + # Otherwise use the specific field + search_fields = [search_type.lower()] + + # Perform the search + results = search_media_db( + search_query=query, + search_fields=search_fields, + keywords=keywords, + page=1, + results_per_page=50 # Adjust as needed + ) + + # Process results + item_map = {} + formatted_results = [] + + for row in results: + id, url, title, type_, content, author, date, prompt, summary = row + # Create a display text that shows relevant info + display_text = f"{title} - {author or 'Unknown'} ({date})" + formatted_results.append(display_text) + item_map[display_text] = id + + return gr.update(choices=formatted_results), item_map + except Exception as e: + logging.error(f"Error in update_dropdown_multiple: {str(e)}") + return gr.update(choices=[]), {} + + def create_chat_interface(): + try: + default_value = None + if default_api_endpoint: + if default_api_endpoint in global_api_endpoints: + default_value = format_api_name(default_api_endpoint) + else: + logging.warning(f"Default API endpoint '{default_api_endpoint}' not found in global_api_endpoints") + except Exception as e: + logging.error(f"Error setting default API endpoint: {str(e)}") + default_value = None custom_css = """ .chatbot-container .message-wrap .message { font-size: 14px !important; @@ -215,9 +271,19 @@ def create_chat_interface(): with gr.Row(): with gr.Column(scale=1): - search_query_input = gr.Textbox(label="Search Query", placeholder="Enter your search query here...") - search_type_input = gr.Radio(choices=["Title", "URL", "Keyword", "Content"], value="Title", - label="Search By") + 
search_query_input = gr.Textbox( + label="Search Query", + placeholder="Enter your search query here..." + ) + search_type_input = gr.Radio( + choices=["Title", "Content", "Author", "Keyword"], + value="Keyword", + label="Search By" + ) + keyword_filter_input = gr.Textbox( + label="Filter by Keywords (comma-separated)", + placeholder="ml, ai, python, etc..." + ) search_button = gr.Button("Search") items_output = gr.Dropdown(label="Select Item", choices=[], interactive=True) item_mapping = gr.State({}) @@ -237,53 +303,60 @@ def create_chat_interface(): with gr.Row(): load_conversations_btn = gr.Button("Load Selected Conversation") - api_endpoint = gr.Dropdown(label="Select API Endpoint", - choices=["Local-LLM", "OpenAI", "Anthropic", "Cohere", "Groq", "DeepSeek", - "Mistral", "OpenRouter", - "Llama.cpp", "Kobold", "Ooba", "Tabbyapi", "VLLM", "ollama", - "HuggingFace"]) + # Refactored API selection dropdown + api_endpoint = gr.Dropdown( + choices=["None"] + [format_api_name(api) for api in global_api_endpoints], + value=default_value, + label="API for Chat Interaction (Optional)" + ) api_key = gr.Textbox(label="API Key (if required)", type="password") + + # Initialize state variables for pagination + current_page_state = gr.State(value=1) + total_pages_state = gr.State(value=1) + custom_prompt_checkbox = gr.Checkbox(label="Use a Custom Prompt", value=False, visible=True) preset_prompt_checkbox = gr.Checkbox(label="Use a pre-set Prompt", value=False, visible=True) - preset_prompt = gr.Dropdown(label="Select Preset Prompt", - choices=load_preset_prompts(), - visible=False) - user_prompt = gr.Textbox(label="Custom Prompt", - placeholder="Enter custom prompt here", - lines=3, - visible=False) - system_prompt_input = gr.Textbox(label="System Prompt", - value="You are a helpful AI assitant", - lines=3, - visible=False) + with gr.Row(): + # Add pagination controls + preset_prompt = gr.Dropdown(label="Select Preset Prompt", + choices=[], + visible=False) + with gr.Row(): + 
prev_page_button = gr.Button("Previous Page", visible=False) + page_display = gr.Markdown("Page 1 of X", visible=False) + next_page_button = gr.Button("Next Page", visible=False) + system_prompt_input = gr.Textbox(label="System Prompt", + value="You are a helpful AI assistant", + lines=3, + visible=False) + with gr.Row(): + user_prompt = gr.Textbox(label="Custom Prompt", + placeholder="Enter custom prompt here", + lines=3, + visible=False) with gr.Column(scale=2): - chatbot = gr.Chatbot(height=600, elem_classes="chatbot-container") + chatbot = gr.Chatbot(height=800, elem_classes="chatbot-container") msg = gr.Textbox(label="Enter your message") submit = gr.Button("Submit") regenerate_button = gr.Button("Regenerate Last Message") + token_count_display = gr.Number(label="Approximate Token Count", value=0, interactive=False) clear_chat_button = gr.Button("Clear Chat") - edit_message_id = gr.Number(label="Message ID to Edit", visible=False) - edit_message_text = gr.Textbox(label="Edit Message", visible=False) - update_message_button = gr.Button("Update Message", visible=False) - - delete_message_id = gr.Number(label="Message ID to Delete", visible=False) - delete_message_button = gr.Button("Delete Message", visible=False) - chat_media_name = gr.Textbox(label="Custom Chat Name(optional)") save_chat_history_to_db = gr.Button("Save Chat History to DataBase") + save_status = gr.Textbox(label="Save Status", interactive=False) save_chat_history_as_file = gr.Button("Save Chat History as File") download_file = gr.File(label="Download Chat History") - save_status = gr.Textbox(label="Save Status", interactive=False) # Restore original functionality search_button.click( - fn=update_dropdown, - inputs=[search_query_input, search_type_input], + fn=update_dropdown_multiple, + inputs=[search_query_input, search_type_input, keyword_filter_input], outputs=[items_output, item_mapping] ) @@ -314,21 +387,72 @@ def create_chat_interface(): clear_chat, outputs=[chatbot, conversation_id] ) + 
+ # Function to handle preset prompt checkbox change + def on_preset_prompt_checkbox_change(is_checked): + if is_checked: + prompts, total_pages, current_page = list_prompts(page=1, per_page=20) + page_display_text = f"Page {current_page} of {total_pages}" + return ( + gr.update(visible=True, interactive=True, choices=prompts), # preset_prompt + gr.update(visible=True), # prev_page_button + gr.update(visible=True), # next_page_button + gr.update(value=page_display_text, visible=True), # page_display + current_page, # current_page_state + total_pages # total_pages_state + ) + else: + return ( + gr.update(visible=False, interactive=False), # preset_prompt + gr.update(visible=False), # prev_page_button + gr.update(visible=False), # next_page_button + gr.update(visible=False), # page_display + 1, # current_page_state + 1 # total_pages_state + ) + + preset_prompt_checkbox.change( + fn=on_preset_prompt_checkbox_change, + inputs=[preset_prompt_checkbox], + outputs=[preset_prompt, prev_page_button, next_page_button, page_display, current_page_state, total_pages_state] + ) + + def on_prev_page_click(current_page, total_pages): + new_page = max(current_page - 1, 1) + prompts, total_pages, current_page = list_prompts(page=new_page, per_page=20) + page_display_text = f"Page {current_page} of {total_pages}" + return gr.update(choices=prompts), gr.update(value=page_display_text), current_page + + prev_page_button.click( + fn=on_prev_page_click, + inputs=[current_page_state, total_pages_state], + outputs=[preset_prompt, page_display, current_page_state] + ) + + def on_next_page_click(current_page, total_pages): + new_page = min(current_page + 1, total_pages) + prompts, total_pages, current_page = list_prompts(page=new_page, per_page=20) + page_display_text = f"Page {current_page} of {total_pages}" + return gr.update(choices=prompts), gr.update(value=page_display_text), current_page + + next_page_button.click( + fn=on_next_page_click, + inputs=[current_page_state, 
total_pages_state], + outputs=[preset_prompt, page_display, current_page_state] + ) + preset_prompt.change( update_prompts, - inputs=preset_prompt, + inputs=[preset_prompt], outputs=[user_prompt, system_prompt_input] ) + custom_prompt_checkbox.change( fn=lambda x: (gr.update(visible=x), gr.update(visible=x)), inputs=[custom_prompt_checkbox], outputs=[user_prompt, system_prompt_input] ) - preset_prompt_checkbox.change( - fn=lambda x: gr.update(visible=x), - inputs=[preset_prompt_checkbox], - outputs=[preset_prompt] - ) + submit.click( chat_wrapper, inputs=[msg, chatbot, media_content, selected_parts, api_endpoint, api_key, user_prompt, conversation_id, @@ -341,6 +465,10 @@ def create_chat_interface(): ).then( # Clear the user prompt after the first message lambda: (gr.update(value=""), gr.update(value="")), outputs=[user_prompt, system_prompt_input] + ).then( + lambda history: approximate_token_count(history), + inputs=[chatbot], + outputs=[token_count_display] ) items_output.change( @@ -348,6 +476,7 @@ def create_chat_interface(): inputs=[items_output, use_content, use_summary, use_prompt, item_mapping], outputs=[media_content, selected_parts] ) + use_content.change(update_selected_parts, inputs=[use_content, use_summary, use_prompt], outputs=[selected_parts]) use_summary.change(update_selected_parts, inputs=[use_content, use_summary, use_prompt], @@ -377,18 +506,6 @@ def create_chat_interface(): outputs=[chat_history] ) - update_message_button.click( - update_message_in_chat, - inputs=[edit_message_id, edit_message_text, chat_history], - outputs=[chatbot] - ) - - delete_message_button.click( - delete_message_from_chat, - inputs=[delete_message_id, chat_history], - outputs=[chatbot] - ) - save_chat_history_as_file.click( save_chat_history, inputs=[chatbot, conversation_id], @@ -403,15 +520,28 @@ def create_chat_interface(): regenerate_button.click( regenerate_last_message, - inputs=[chatbot, media_content, selected_parts, api_endpoint, api_key, user_prompt, 
temperature, system_prompt_input], + inputs=[chatbot, media_content, selected_parts, api_endpoint, api_key, user_prompt, temperature, + system_prompt_input], outputs=[chatbot, save_status] + ).then( + lambda history: approximate_token_count(history), + inputs=[chatbot], + outputs=[token_count_display] ) - chatbot.select(show_edit_message, None, [edit_message_text, edit_message_id, update_message_button]) - chatbot.select(show_delete_message, None, [delete_message_id, delete_message_button]) - def create_chat_interface_stacked(): + try: + default_value = None + if default_api_endpoint: + if default_api_endpoint in global_api_endpoints: + default_value = format_api_name(default_api_endpoint) + else: + logging.warning(f"Default API endpoint '{default_api_endpoint}' not found in global_api_endpoints") + except Exception as e: + logging.error(f"Error setting default API endpoint: {str(e)}") + default_value = None + custom_css = """ .chatbot-container .message-wrap .message { font-size: 14px !important; @@ -426,9 +556,19 @@ def create_chat_interface_stacked(): with gr.Row(): with gr.Column(): - search_query_input = gr.Textbox(label="Search Query", placeholder="Enter your search query here...") - search_type_input = gr.Radio(choices=["Title", "URL", "Keyword", "Content"], value="Title", - label="Search By") + search_query_input = gr.Textbox( + label="Search Query", + placeholder="Enter your search query here..." + ) + search_type_input = gr.Radio( + choices=["Title", "Content", "Author", "Keyword"], + value="Keyword", + label="Search By" + ) + keyword_filter_input = gr.Textbox( + label="Filter by Keywords (comma-separated)", + placeholder="ml, ai, python, etc..." 
+ ) search_button = gr.Button("Search") items_output = gr.Dropdown(label="Select Item", choices=[], interactive=True) item_mapping = gr.State({}) @@ -446,45 +586,165 @@ def create_chat_interface_stacked(): search_conversations_btn = gr.Button("Search Conversations") load_conversations_btn = gr.Button("Load Selected Conversation") with gr.Column(): - api_endpoint = gr.Dropdown(label="Select API Endpoint", - choices=["Local-LLM", "OpenAI", "Anthropic", "Cohere", "Groq", "DeepSeek", - "OpenRouter", "Mistral", "Llama.cpp", "Kobold", "Ooba", "Tabbyapi", - "VLLM", "ollama", "HuggingFace"]) + # Refactored API selection dropdown + api_endpoint = gr.Dropdown( + choices=["None"] + [format_api_name(api) for api in global_api_endpoints], + value=default_value, + label="API for Chat Interaction (Optional)" + ) api_key = gr.Textbox(label="API Key (if required)", type="password") - preset_prompt = gr.Dropdown(label="Select Preset Prompt", - choices=load_preset_prompts(), - visible=True) - system_prompt = gr.Textbox(label="System Prompt", - value="You are a helpful AI assistant.", - lines=3, - visible=True) - user_prompt = gr.Textbox(label="Custom User Prompt", - placeholder="Enter custom prompt here", - lines=3, - visible=True) + + # Initialize state variables for pagination + current_page_state = gr.State(value=1) + total_pages_state = gr.State(value=1) + + custom_prompt_checkbox = gr.Checkbox( + label="Use a Custom Prompt", + value=False, + visible=True + ) + preset_prompt_checkbox = gr.Checkbox( + label="Use a pre-set Prompt", + value=False, + visible=True + ) + + with gr.Row(): + preset_prompt = gr.Dropdown( + label="Select Preset Prompt", + choices=[], + visible=False + ) + with gr.Row(): + prev_page_button = gr.Button("Previous Page", visible=False) + page_display = gr.Markdown("Page 1 of X", visible=False) + next_page_button = gr.Button("Next Page", visible=False) + + system_prompt = gr.Textbox( + label="System Prompt", + value="You are a helpful AI assistant.", + lines=4, 
+ visible=False + ) + user_prompt = gr.Textbox( + label="Custom User Prompt", + placeholder="Enter custom prompt here", + lines=4, + visible=False + ) gr.Markdown("Scroll down for the chat window...") with gr.Row(): with gr.Column(scale=1): - chatbot = gr.Chatbot(height=600, elem_classes="chatbot-container") + chatbot = gr.Chatbot(height=800, elem_classes="chatbot-container") msg = gr.Textbox(label="Enter your message") with gr.Row(): with gr.Column(): submit = gr.Button("Submit") regenerate_button = gr.Button("Regenerate Last Message") + token_count_display = gr.Number(label="Approximate Token Count", value=0, interactive=False) clear_chat_button = gr.Button("Clear Chat") chat_media_name = gr.Textbox(label="Custom Chat Name(optional)", visible=True) save_chat_history_to_db = gr.Button("Save Chat History to DataBase") + save_status = gr.Textbox(label="Save Status", interactive=False) save_chat_history_as_file = gr.Button("Save Chat History as File") with gr.Column(): download_file = gr.File(label="Download Chat History") # Restore original functionality search_button.click( - fn=update_dropdown, - inputs=[search_query_input, search_type_input], + fn=update_dropdown_multiple, + inputs=[search_query_input, search_type_input, keyword_filter_input], outputs=[items_output, item_mapping] ) + def search_conversations(query): + try: + # Use RAG search with title search + if query and query.strip(): + results, _, _ = search_conversations_by_keywords(title_query=query.strip()) + else: + results, _, _ = get_all_conversations() + + if not results: + return gr.update(choices=[]) + + # Format choices to match UI + conversation_options = [ + (f"{conv['title']} (ID: {conv['conversation_id'][:8]})", conv['conversation_id']) + for conv in results + ] + + return gr.update(choices=conversation_options) + except Exception as e: + logging.error(f"Error searching conversations: {str(e)}") + return gr.update(choices=[]) + + def load_conversation(conversation_id): + if not conversation_id: 
+ return [], None + + try: + # Use RAG load function + messages, _, _ = load_chat_history(conversation_id) + + # Convert to chatbot history format + history = [ + (content, None) if role == 'user' else (None, content) + for role, content in messages + ] + + return history, conversation_id + except Exception as e: + logging.error(f"Error loading conversation: {str(e)}") + return [], None + + def save_chat_history_to_db_wrapper(chatbot, conversation_id, media_content, chat_name=None): + log_counter("save_chat_history_to_db_attempt") + start_time = time.time() + logging.info(f"Attempting to save chat history. Media content type: {type(media_content)}") + + try: + # First check if we can access the database + try: + with get_db_connection() as conn: + cursor = conn.cursor() + cursor.execute("SELECT 1") + except sqlite3.DatabaseError as db_error: + logging.error(f"Database is corrupted or inaccessible: {str(db_error)}") + return conversation_id, gr.update( + value="Database error: The database file appears to be corrupted. Please contact support.") + + # For both new and existing conversations + try: + if not conversation_id: + title = chat_name if chat_name else "Untitled Conversation" + conversation_id = start_new_conversation(title=title) + logging.info(f"Created new conversation with ID: {conversation_id}") + + # Update existing messages + delete_messages_in_conversation(conversation_id) + for user_msg, assistant_msg in chatbot: + if user_msg: + save_message(conversation_id, "user", user_msg) + if assistant_msg: + save_message(conversation_id, "assistant", assistant_msg) + except sqlite3.DatabaseError as db_error: + logging.error(f"Database error during message save: {str(db_error)}") + return conversation_id, gr.update( + value="Database error: Unable to save messages. 
Please try again or contact support.") + + save_duration = time.time() - start_time + log_histogram("save_chat_history_to_db_duration", save_duration) + log_counter("save_chat_history_to_db_success") + + return conversation_id, gr.update(value="Chat history saved successfully!") + + except Exception as e: + log_counter("save_chat_history_to_db_error", labels={"error": str(e)}) + error_message = f"Failed to save chat history: {str(e)}" + logging.error(error_message, exc_info=True) + return conversation_id, gr.update(value=error_message) + def update_prompts(preset_name): prompts = update_user_prompt(preset_name) return ( @@ -492,13 +752,85 @@ def create_chat_interface_stacked(): gr.update(value=prompts["system_prompt"], visible=True) ) + def clear_chat(): + return [], None, 0 # Empty history, conversation_id, and token count + clear_chat_button.click( clear_chat, - outputs=[chatbot, conversation_id] + outputs=[chatbot, conversation_id, token_count_display] ) + + # Handle custom prompt checkbox change + def on_custom_prompt_checkbox_change(is_checked): + return ( + gr.update(visible=is_checked), + gr.update(visible=is_checked) + ) + + custom_prompt_checkbox.change( + fn=on_custom_prompt_checkbox_change, + inputs=[custom_prompt_checkbox], + outputs=[user_prompt, system_prompt] + ) + + # Handle preset prompt checkbox change + def on_preset_prompt_checkbox_change(is_checked): + if is_checked: + prompts, total_pages, current_page = list_prompts(page=1, per_page=20) + page_display_text = f"Page {current_page} of {total_pages}" + return ( + gr.update(visible=True, interactive=True, choices=prompts), # preset_prompt + gr.update(visible=True), # prev_page_button + gr.update(visible=True), # next_page_button + gr.update(value=page_display_text, visible=True), # page_display + current_page, # current_page_state + total_pages # total_pages_state + ) + else: + return ( + gr.update(visible=False, interactive=False), # preset_prompt + gr.update(visible=False), # prev_page_button + 
gr.update(visible=False), # next_page_button + gr.update(visible=False), # page_display + 1, # current_page_state + 1 # total_pages_state + ) + + preset_prompt_checkbox.change( + fn=on_preset_prompt_checkbox_change, + inputs=[preset_prompt_checkbox], + outputs=[preset_prompt, prev_page_button, next_page_button, page_display, current_page_state, total_pages_state] + ) + + # Pagination button functions + def on_prev_page_click(current_page, total_pages): + new_page = max(current_page - 1, 1) + prompts, total_pages, current_page = list_prompts(page=new_page, per_page=20) + page_display_text = f"Page {current_page} of {total_pages}" + return gr.update(choices=prompts), gr.update(value=page_display_text), current_page + + prev_page_button.click( + fn=on_prev_page_click, + inputs=[current_page_state, total_pages_state], + outputs=[preset_prompt, page_display, current_page_state] + ) + + def on_next_page_click(current_page, total_pages): + new_page = min(current_page + 1, total_pages) + prompts, total_pages, current_page = list_prompts(page=new_page, per_page=20) + page_display_text = f"Page {current_page} of {total_pages}" + return gr.update(choices=prompts), gr.update(value=page_display_text), current_page + + next_page_button.click( + fn=on_next_page_click, + inputs=[current_page_state, total_pages_state], + outputs=[preset_prompt, page_display, current_page_state] + ) + + # Update prompts when a preset is selected preset_prompt.change( update_prompts, - inputs=preset_prompt, + inputs=[preset_prompt], outputs=[user_prompt, system_prompt] ) @@ -507,13 +839,14 @@ def create_chat_interface_stacked(): inputs=[msg, chatbot, media_content, selected_parts, api_endpoint, api_key, user_prompt, conversation_id, save_conversation, temp, system_prompt], outputs=[msg, chatbot, conversation_id] - ).then( # Clear the message box after submission + ).then( lambda x: gr.update(value=""), inputs=[chatbot], outputs=[msg] - ).then( # Clear the user prompt after the first message - lambda: 
gr.update(value=""), - outputs=[user_prompt, system_prompt] + ).then( + lambda history: approximate_token_count(history), + inputs=[chatbot], + outputs=[token_count_display] ) items_output.change( @@ -559,18 +892,31 @@ def create_chat_interface_stacked(): save_chat_history_to_db.click( save_chat_history_to_db_wrapper, inputs=[chatbot, conversation_id, media_content, chat_media_name], - outputs=[conversation_id, gr.Textbox(label="Save Status")] + outputs=[conversation_id, save_status] ) regenerate_button.click( regenerate_last_message, inputs=[chatbot, media_content, selected_parts, api_endpoint, api_key, user_prompt, temp, system_prompt], outputs=[chatbot, gr.Textbox(label="Regenerate Status")] + ).then( + lambda history: approximate_token_count(history), + inputs=[chatbot], + outputs=[token_count_display] ) -# FIXME - System prompts def create_chat_interface_multi_api(): + try: + default_value = None + if default_api_endpoint: + if default_api_endpoint in global_api_endpoints: + default_value = format_api_name(default_api_endpoint) + else: + logging.warning(f"Default API endpoint '{default_api_endpoint}' not found in global_api_endpoints") + except Exception as e: + logging.error(f"Error setting default API endpoint: {str(e)}") + default_value = None custom_css = """ .chatbot-container .message-wrap .message { font-size: 14px !important; @@ -596,9 +942,31 @@ def create_chat_interface_multi_api(): use_summary = gr.Checkbox(label="Use Summary") use_prompt = gr.Checkbox(label="Use Prompt") with gr.Column(): - preset_prompt = gr.Dropdown(label="Select Preset Prompt", choices=load_preset_prompts(), visible=True) - system_prompt = gr.Textbox(label="System Prompt", value="You are a helpful AI assistant.", lines=5) - user_prompt = gr.Textbox(label="Modify Prompt (Prefixed to your message every time)", lines=5, value="", visible=True) + # Initialize state variables for pagination + current_page_state = gr.State(value=1) + total_pages_state = gr.State(value=1) + + 
custom_prompt_checkbox = gr.Checkbox(label="Use a Custom Prompt", + value=False, + visible=True) + preset_prompt_checkbox = gr.Checkbox(label="Use a pre-set Prompt", + value=False, + visible=True) + with gr.Row(): + # Add pagination controls + preset_prompt = gr.Dropdown(label="Select Preset Prompt", + choices=[], + visible=False) + with gr.Row(): + prev_page_button = gr.Button("Previous Page", visible=False) + page_display = gr.Markdown("Page 1 of X", visible=False) + next_page_button = gr.Button("Next Page", visible=False) + system_prompt = gr.Textbox(label="System Prompt", + value="You are a helpful AI assistant.", + lines=5, + visible=True) + user_prompt = gr.Textbox(label="Modify Prompt (Prefixed to your message every time)", lines=5, + value="", visible=True) with gr.Row(): chatbots = [] @@ -606,17 +974,23 @@ def create_chat_interface_multi_api(): api_keys = [] temperatures = [] regenerate_buttons = [] + token_count_displays = [] for i in range(3): with gr.Column(): gr.Markdown(f"### Chat Window {i + 1}") - api_endpoint = gr.Dropdown(label=f"API Endpoint {i + 1}", - choices=["Local-LLM", "OpenAI", "Anthropic", "Cohere", "Groq", - "DeepSeek", "Mistral", "OpenRouter", "Llama.cpp", "Kobold", - "Ooba", "Tabbyapi", "VLLM", "ollama", "HuggingFace"]) + # Refactored API selection dropdown + api_endpoint = gr.Dropdown( + choices=["None"] + [format_api_name(api) for api in global_api_endpoints], + value=default_value, + label="API for Chat Interaction (Optional)" + ) api_key = gr.Textbox(label=f"API Key {i + 1} (if required)", type="password") temperature = gr.Slider(label=f"Temperature {i + 1}", minimum=0.0, maximum=1.0, step=0.05, value=0.7) chatbot = gr.Chatbot(height=800, elem_classes="chat-window") + token_count_display = gr.Number(label=f"Approximate Token Count {i + 1}", value=0, + interactive=False) + token_count_displays.append(token_count_display) regenerate_button = gr.Button(f"Regenerate Last Message {i + 1}") chatbots.append(chatbot) 
api_endpoints.append(api_endpoint) @@ -642,16 +1016,103 @@ def create_chat_interface_multi_api(): outputs=[items_output, item_mapping] ) + def update_prompts(preset_name): + prompts = update_user_prompt(preset_name) + return ( + gr.update(value=prompts["user_prompt"], visible=True), + gr.update(value=prompts["system_prompt"], visible=True) + ) + + def on_custom_prompt_checkbox_change(is_checked): + return ( + gr.update(visible=is_checked), + gr.update(visible=is_checked) + ) + + custom_prompt_checkbox.change( + fn=on_custom_prompt_checkbox_change, + inputs=[custom_prompt_checkbox], + outputs=[user_prompt, system_prompt] + ) + + def clear_all_chats(): + return [[]] * 3 + [[]] * 3 + [0] * 3 + + clear_chat_button.click( + clear_all_chats, + outputs=chatbots + chat_history + token_count_displays + ) + + def on_preset_prompt_checkbox_change(is_checked): + if is_checked: + prompts, total_pages, current_page = list_prompts(page=1, per_page=10) + page_display_text = f"Page {current_page} of {total_pages}" + return ( + gr.update(visible=True, interactive=True, choices=prompts), # preset_prompt + gr.update(visible=True), # prev_page_button + gr.update(visible=True), # next_page_button + gr.update(value=page_display_text, visible=True), # page_display + current_page, # current_page_state + total_pages # total_pages_state + ) + else: + return ( + gr.update(visible=False, interactive=False), # preset_prompt + gr.update(visible=False), # prev_page_button + gr.update(visible=False), # next_page_button + gr.update(visible=False), # page_display + 1, # current_page_state + 1 # total_pages_state + ) + preset_prompt.change(update_user_prompt, inputs=preset_prompt, outputs=user_prompt) + preset_prompt_checkbox.change( + fn=on_preset_prompt_checkbox_change, + inputs=[preset_prompt_checkbox], + outputs=[preset_prompt, prev_page_button, next_page_button, page_display, current_page_state, + total_pages_state] + ) + + def on_prev_page_click(current_page, total_pages): + new_page = 
max(current_page - 1, 1) + prompts, total_pages, current_page = list_prompts(page=new_page, per_page=10) + page_display_text = f"Page {current_page} of {total_pages}" + return gr.update(choices=prompts), gr.update(value=page_display_text), current_page + + prev_page_button.click( + fn=on_prev_page_click, + inputs=[current_page_state, total_pages_state], + outputs=[preset_prompt, page_display, current_page_state] + ) + + def on_next_page_click(current_page, total_pages): + new_page = min(current_page + 1, total_pages) + prompts, total_pages, current_page = list_prompts(page=new_page, per_page=10) + page_display_text = f"Page {current_page} of {total_pages}" + return gr.update(choices=prompts), gr.update(value=page_display_text), current_page + + next_page_button.click( + fn=on_next_page_click, + inputs=[current_page_state, total_pages_state], + outputs=[preset_prompt, page_display, current_page_state] + ) + + # Update prompts when a preset is selected + preset_prompt.change( + update_prompts, + inputs=[preset_prompt], + outputs=[user_prompt, system_prompt] + ) def clear_all_chats(): - return [[]] * 3 + [[]] * 3 + return [[]] * 3 + [[]] * 3 + [0] * 3 clear_chat_button.click( clear_all_chats, - outputs=chatbots + chat_history + outputs=chatbots + chat_history + token_count_displays ) + def chat_wrapper_multi(message, custom_prompt, system_prompt, *args): chat_histories = args[:3] chatbots = args[3:6] @@ -681,6 +1142,11 @@ def create_chat_interface_multi_api(): return [gr.update(value="")] + new_chatbots + new_chat_histories + def update_token_counts(*histories): + token_counts = [] + for history in histories: + token_counts.append(approximate_token_count(history)) + return token_counts def regenerate_last_message(chat_history, chatbot, media_content, selected_parts, api_endpoint, api_key, custom_prompt, temperature, system_prompt): if not chat_history: @@ -717,8 +1183,13 @@ def create_chat_interface_multi_api(): for i in range(3): regenerate_buttons[i].click( 
regenerate_last_message, - inputs=[chat_history[i], chatbots[i], media_content, selected_parts, api_endpoints[i], api_keys[i], user_prompt, temperatures[i], system_prompt], + inputs=[chat_history[i], chatbots[i], media_content, selected_parts, api_endpoints[i], api_keys[i], + user_prompt, temperatures[i], system_prompt], outputs=[chatbots[i], chat_history[i], gr.Textbox(label=f"Regenerate Status {i + 1}")] + ).then( + lambda history: approximate_token_count(history), + inputs=[chat_history[i]], + outputs=[token_count_displays[i]] ) # In the create_chat_interface_multi_api function: @@ -731,6 +1202,10 @@ def create_chat_interface_multi_api(): ).then( lambda: (gr.update(value=""), gr.update(value="")), outputs=[msg, user_prompt] + ).then( + update_token_counts, + inputs=chat_history, + outputs=token_count_displays ) items_output.change( @@ -747,8 +1222,17 @@ def create_chat_interface_multi_api(): ) - def create_chat_interface_four(): + try: + default_value = None + if default_api_endpoint: + if default_api_endpoint in global_api_endpoints: + default_value = format_api_name(default_api_endpoint) + else: + logging.warning(f"Default API endpoint '{default_api_endpoint}' not found in global_api_endpoints") + except Exception as e: + logging.error(f"Error setting default API endpoint: {str(e)}") + default_value = None custom_css = """ .chatbot-container .message-wrap .message { font-size: 14px !important; @@ -762,17 +1246,32 @@ def create_chat_interface_four(): with gr.TabItem("Four Independent API Chats", visible=True): gr.Markdown("# Four Independent API Chat Interfaces") + # Initialize prompts during component creation + prompts, total_pages, current_page = list_prompts(page=1, per_page=10) + current_page_state = gr.State(value=current_page) + total_pages_state = gr.State(value=total_pages) + page_display_text = f"Page {current_page} of {total_pages}" + with gr.Row(): with gr.Column(): preset_prompt = gr.Dropdown( - label="Select Preset Prompt", - 
choices=load_preset_prompts(), + label="Select Preset Prompt (This will be prefixed to your messages, recommend copy/pasting and then clearing the User Prompt box)", + choices=prompts, visible=True ) + prev_page_button = gr.Button("Previous Page", visible=True) + page_display = gr.Markdown(page_display_text, visible=True) + next_page_button = gr.Button("Next Page", visible=True) user_prompt = gr.Textbox( - label="Modify Prompt", + label="Modify User Prompt", lines=3 ) + system_prompt = gr.Textbox( + label="System Prompt", + value="You are a helpful AI assistant.", + lines=3 + ) + with gr.Column(): gr.Markdown("Scroll down for the chat windows...") @@ -781,13 +1280,11 @@ def create_chat_interface_four(): def create_single_chat_interface(index, user_prompt_component): with gr.Column(): gr.Markdown(f"### Chat Window {index + 1}") + # Refactored API selection dropdown api_endpoint = gr.Dropdown( - label=f"API Endpoint {index + 1}", - choices=[ - "Local-LLM", "OpenAI", "Anthropic", "Cohere", "Groq", - "DeepSeek", "Mistral", "OpenRouter", "Llama.cpp", "Kobold", - "Ooba", "Tabbyapi", "VLLM", "ollama", "HuggingFace" - ] + choices=["None"] + [format_api_name(api) for api in global_api_endpoints], + value=default_value, + label="API for Chat Interaction (Optional)" ) api_key = gr.Textbox( label=f"API Key {index + 1} (if required)", @@ -804,6 +1301,8 @@ def create_chat_interface_four(): msg = gr.Textbox(label=f"Enter your message for Chat {index + 1}") submit = gr.Button(f"Submit to Chat {index + 1}") regenerate_button = gr.Button(f"Regenerate Last Message {index + 1}") + token_count_display = gr.Number(label=f"Approximate Token Count {index + 1}", value=0, + interactive=False) clear_chat_button = gr.Button(f"Clear Chat {index + 1}") # State to maintain chat history @@ -819,7 +1318,8 @@ def create_chat_interface_four(): 'submit': submit, 'regenerate_button': regenerate_button, 'clear_chat_button': clear_chat_button, - 'chat_history': chat_history + 'chat_history': 
chat_history, + 'token_count_display': token_count_display }) # Create four chat interfaces arranged in a 2x2 grid @@ -830,10 +1330,47 @@ def create_chat_interface_four(): create_single_chat_interface(i * 2 + j, user_prompt) # Update user_prompt based on preset_prompt selection + def update_prompts(preset_name): + prompts = update_user_prompt(preset_name) + return gr.update(value=prompts["user_prompt"]), gr.update(value=prompts["system_prompt"]) + preset_prompt.change( - fn=update_user_prompt, - inputs=preset_prompt, - outputs=user_prompt + fn=update_prompts, + inputs=[preset_prompt], + outputs=[user_prompt, system_prompt] + ) + + # Pagination button functions + def on_prev_page_click(current_page, total_pages): + new_page = max(current_page - 1, 1) + prompts, total_pages, current_page = list_prompts(page=new_page, per_page=10) + page_display_text = f"Page {current_page} of {total_pages}" + return ( + gr.update(choices=prompts), + gr.update(value=page_display_text), + current_page + ) + + prev_page_button.click( + fn=on_prev_page_click, + inputs=[current_page_state, total_pages_state], + outputs=[preset_prompt, page_display, current_page_state] + ) + + def on_next_page_click(current_page, total_pages): + new_page = min(current_page + 1, total_pages) + prompts, total_pages, current_page = list_prompts(page=new_page, per_page=10) + page_display_text = f"Page {current_page} of {total_pages}" + return ( + gr.update(choices=prompts), + gr.update(value=page_display_text), + current_page + ) + + next_page_button.click( + fn=on_next_page_click, + inputs=[current_page_state, total_pages_state], + outputs=[preset_prompt, page_display, current_page_state] ) def chat_wrapper_single(message, chat_history, api_endpoint, api_key, temperature, user_prompt): @@ -913,6 +1450,10 @@ def create_chat_interface_four(): interface['chatbot'], interface['chat_history'] ] + ).then( + lambda history: approximate_token_count(history), + inputs=[interface['chat_history']], + 
outputs=[interface['token_count_display']] ) interface['regenerate_button'].click( @@ -929,12 +1470,18 @@ def create_chat_interface_four(): interface['chat_history'], gr.Textbox(label="Regenerate Status") ] + ).then( + lambda history: approximate_token_count(history), + inputs=[interface['chat_history']], + outputs=[interface['token_count_display']] ) + def clear_chat_single(): + return [], [], 0 + interface['clear_chat_button'].click( clear_chat_single, - inputs=[], - outputs=[interface['chatbot'], interface['chat_history']] + outputs=[interface['chatbot'], interface['chat_history'], interface['token_count_display']] ) @@ -953,233 +1500,11 @@ def chat_wrapper_single(message, chat_history, chatbot, api_endpoint, api_key, t return new_msg, updated_chatbot, new_history, new_conv_id - -# FIXME - Finish implementing functions + testing/valdidation -def create_chat_management_tab(): - with gr.TabItem("Chat Management", visible=True): - gr.Markdown("# Chat Management") - - with gr.Row(): - search_query = gr.Textbox(label="Search Conversations") - search_button = gr.Button("Search") - - conversation_list = gr.Dropdown(label="Select Conversation", choices=[]) - conversation_mapping = gr.State({}) - - with gr.Tabs(): - with gr.TabItem("Edit", visible=True): - chat_content = gr.TextArea(label="Chat Content (JSON)", lines=20, max_lines=50) - save_button = gr.Button("Save Changes") - delete_button = gr.Button("Delete Conversation", variant="stop") - - with gr.TabItem("Preview", visible=True): - chat_preview = gr.HTML(label="Chat Preview") - result_message = gr.Markdown("") - - def search_conversations(query): - conversations = search_chat_conversations(query) - choices = [f"{conv['conversation_name']} (Media: {conv['media_title']}, ID: {conv['id']})" for conv in - conversations] - mapping = {choice: conv['id'] for choice, conv in zip(choices, conversations)} - return gr.update(choices=choices), mapping - - def load_conversations(selected, conversation_mapping): - 
logging.info(f"Selected: {selected}") - logging.info(f"Conversation mapping: {conversation_mapping}") - - try: - if selected and selected in conversation_mapping: - conversation_id = conversation_mapping[selected] - messages = get_chat_messages(conversation_id) - conversation_data = { - "conversation_id": conversation_id, - "messages": messages - } - json_content = json.dumps(conversation_data, indent=2) - - # Create HTML preview - html_preview = "
" - for msg in messages: - sender_style = "background-color: #e6f3ff;" if msg[ - 'sender'] == 'user' else "background-color: #f0f0f0;" - html_preview += f"
" - html_preview += f"{msg['sender']}: {html.escape(msg['message'])}
" - html_preview += f"Timestamp: {msg['timestamp']}" - html_preview += "
" - html_preview += "
" - - logging.info("Returning json_content and html_preview") - return json_content, html_preview - else: - logging.warning("No conversation selected or not in mapping") - return "", "

No conversation selected

" - except Exception as e: - logging.error(f"Error in load_conversations: {str(e)}") - return f"Error: {str(e)}", "

Error loading conversation

" - - def validate_conversation_json(content): - try: - data = json.loads(content) - if not isinstance(data, dict): - return False, "Invalid JSON structure: root should be an object" - if "conversation_id" not in data or not isinstance(data["conversation_id"], int): - return False, "Missing or invalid conversation_id" - if "messages" not in data or not isinstance(data["messages"], list): - return False, "Missing or invalid messages array" - for msg in data["messages"]: - if not all(key in msg for key in ["sender", "message"]): - return False, "Invalid message structure: missing required fields" - return True, data - except json.JSONDecodeError as e: - return False, f"Invalid JSON: {str(e)}" - - def save_conversation(selected, conversation_mapping, content): - if not selected or selected not in conversation_mapping: - return "Please select a conversation before saving.", "

No changes made

" - - conversation_id = conversation_mapping[selected] - is_valid, result = validate_conversation_json(content) - - if not is_valid: - return f"Error: {result}", "

No changes made due to error

" - - conversation_data = result - if conversation_data["conversation_id"] != conversation_id: - return "Error: Conversation ID mismatch.", "

No changes made due to ID mismatch

" - - try: - with db.get_connection() as conn: - conn.execute("BEGIN TRANSACTION") - cursor = conn.cursor() - - # Backup original conversation - cursor.execute("SELECT * FROM ChatMessages WHERE conversation_id = ?", (conversation_id,)) - original_messages = cursor.fetchall() - backup_data = json.dumps({"conversation_id": conversation_id, "messages": original_messages}) - - # You might want to save this backup_data somewhere - - # Delete existing messages - cursor.execute("DELETE FROM ChatMessages WHERE conversation_id = ?", (conversation_id,)) - - # Insert updated messages - for message in conversation_data["messages"]: - cursor.execute(''' - INSERT INTO ChatMessages (conversation_id, sender, message, timestamp) - VALUES (?, ?, ?, COALESCE(?, CURRENT_TIMESTAMP)) - ''', (conversation_id, message["sender"], message["message"], message.get("timestamp"))) - - conn.commit() - - # Create updated HTML preview - html_preview = "
" - for msg in conversation_data["messages"]: - sender_style = "background-color: #e6f3ff;" if msg[ - 'sender'] == 'user' else "background-color: #f0f0f0;" - html_preview += f"
" - html_preview += f"{msg['sender']}: {html.escape(msg['message'])}
" - html_preview += f"Timestamp: {msg.get('timestamp', 'N/A')}" - html_preview += "
" - html_preview += "
" - - return "Conversation updated successfully.", html_preview - except sqlite3.Error as e: - conn.rollback() - logging.error(f"Database error in save_conversation: {e}") - return f"Error updating conversation: {str(e)}", "

Error occurred while saving

" - except Exception as e: - conn.rollback() - logging.error(f"Unexpected error in save_conversation: {e}") - return f"Unexpected error: {str(e)}", "

Unexpected error occurred

" - - def delete_conversation(selected, conversation_mapping): - if not selected or selected not in conversation_mapping: - return "Please select a conversation before deleting.", "

No changes made

", gr.update(choices=[]) - - conversation_id = conversation_mapping[selected] - - try: - with db.get_connection() as conn: - cursor = conn.cursor() - - # Delete messages associated with the conversation - cursor.execute("DELETE FROM ChatMessages WHERE conversation_id = ?", (conversation_id,)) - - # Delete the conversation itself - cursor.execute("DELETE FROM ChatConversations WHERE id = ?", (conversation_id,)) - - conn.commit() - - # Update the conversation list - remaining_conversations = [choice for choice in conversation_mapping.keys() if choice != selected] - updated_mapping = {choice: conversation_mapping[choice] for choice in remaining_conversations} - - return "Conversation deleted successfully.", "

Conversation deleted

", gr.update(choices=remaining_conversations) - except sqlite3.Error as e: - conn.rollback() - logging.error(f"Database error in delete_conversation: {e}") - return f"Error deleting conversation: {str(e)}", "

Error occurred while deleting

", gr.update() - except Exception as e: - conn.rollback() - logging.error(f"Unexpected error in delete_conversation: {e}") - return f"Unexpected error: {str(e)}", "

Unexpected error occurred

", gr.update() - - def parse_formatted_content(formatted_content): - lines = formatted_content.split('\n') - conversation_id = int(lines[0].split(': ')[1]) - timestamp = lines[1].split(': ')[1] - history = [] - current_role = None - current_content = None - for line in lines[3:]: - if line.startswith("Role: "): - if current_role is not None: - history.append({"role": current_role, "content": ["", current_content]}) - current_role = line.split(': ')[1] - elif line.startswith("Content: "): - current_content = line.split(': ', 1)[1] - if current_role is not None: - history.append({"role": current_role, "content": ["", current_content]}) - return json.dumps({ - "conversation_id": conversation_id, - "timestamp": timestamp, - "history": history - }, indent=2) - - search_button.click( - search_conversations, - inputs=[search_query], - outputs=[conversation_list, conversation_mapping] - ) - - conversation_list.change( - load_conversations, - inputs=[conversation_list, conversation_mapping], - outputs=[chat_content, chat_preview] - ) - - save_button.click( - save_conversation, - inputs=[conversation_list, conversation_mapping, chat_content], - outputs=[result_message, chat_preview] - ) - - delete_button.click( - delete_conversation, - inputs=[conversation_list, conversation_mapping], - outputs=[result_message, chat_preview, conversation_list] - ) - - return search_query, search_button, conversation_list, conversation_mapping, chat_content, save_button, delete_button, result_message, chat_preview - - - # Mock function to simulate LLM processing def process_with_llm(workflow, context, prompt, api_endpoint, api_key): api_key_snippet = api_key[:5] + "..." if api_key else "Not provided" return f"LLM output using {api_endpoint} (API Key: {api_key_snippet}) for {workflow} with context: {context[:30]}... and prompt: {prompt[:30]}..." 
- # # End of Chat_ui.py ####################################################################################################################### \ No newline at end of file diff --git a/App_Function_Libraries/Gradio_UI/Embeddings_tab.py b/App_Function_Libraries/Gradio_UI/Embeddings_tab.py index 3f4841f9c8b52b50bcc643ed7239c123f33dd003..8c49a86a0f91d2c35b5ef841d2148c32551db7ee 100644 --- a/App_Function_Libraries/Gradio_UI/Embeddings_tab.py +++ b/App_Function_Libraries/Gradio_UI/Embeddings_tab.py @@ -4,6 +4,7 @@ # Imports import json import logging +import os # # External Imports import gradio as gr @@ -11,26 +12,58 @@ import numpy as np from tqdm import tqdm # # Local Imports -from App_Function_Libraries.DB.DB_Manager import get_all_content_from_database +from App_Function_Libraries.DB.DB_Manager import get_all_content_from_database, get_all_conversations, \ + get_conversation_text, get_note_by_id +from App_Function_Libraries.DB.RAG_QA_Chat_DB import get_all_notes from App_Function_Libraries.RAG.ChromaDB_Library import chroma_client, \ store_in_chroma, situate_context from App_Function_Libraries.RAG.Embeddings_Create import create_embedding, create_embeddings_batch from App_Function_Libraries.Chunk_Lib import improved_chunking_process, chunk_for_embedding +from App_Function_Libraries.Utils.Utils import load_and_log_configs + + # ######################################################################################################################## # # Functions: def create_embeddings_tab(): + # Load configuration first + config = load_and_log_configs() + if not config: + raise ValueError("Could not load configuration") + + # Get database paths from config + db_config = config['db_config'] + media_db_path = db_config['sqlite_path'] + rag_qa_db_path = os.path.join(os.path.dirname(media_db_path), "rag_qa.db") + character_chat_db_path = os.path.join(os.path.dirname(media_db_path), "chatDB.db") + chroma_db_path = db_config['chroma_db_path'] + with gr.TabItem("Create 
Embeddings", visible=True): gr.Markdown("# Create Embeddings for All Content") with gr.Row(): with gr.Column(): + # Database selection at the top + database_selection = gr.Radio( + choices=["Media DB", "RAG Chat", "Character Chat"], + label="Select Content Source", + value="Media DB", + info="Choose which database to create embeddings from" + ) + + # Add database path display + current_db_path = gr.Textbox( + label="Current Database Path", + value=media_db_path, + interactive=False + ) + embedding_provider = gr.Radio( choices=["huggingface", "local", "openai"], label="Select Embedding Provider", - value="huggingface" + value=config['embedding_config']['embedding_provider'] or "huggingface" ) gr.Markdown("Note: Local provider requires a running Llama.cpp/llamafile server.") gr.Markdown("OpenAI provider requires a valid API key.") @@ -65,22 +98,24 @@ def create_embeddings_tab(): embedding_api_url = gr.Textbox( label="API URL (for local provider)", - value="http://localhost:8080/embedding", + value=config['embedding_config']['embedding_api_url'], visible=False ) - # Add chunking options + # Add chunking options with config defaults chunking_method = gr.Dropdown( choices=["words", "sentences", "paragraphs", "tokens", "semantic"], label="Chunking Method", value="words" ) max_chunk_size = gr.Slider( - minimum=1, maximum=8000, step=1, value=500, + minimum=1, maximum=8000, step=1, + value=config['embedding_config']['chunk_size'], label="Max Chunk Size" ) chunk_overlap = gr.Slider( - minimum=0, maximum=4000, step=1, value=200, + minimum=0, maximum=4000, step=1, + value=config['embedding_config']['overlap'], label="Chunk Overlap" ) adaptive_chunking = gr.Checkbox( @@ -92,6 +127,7 @@ def create_embeddings_tab(): with gr.Column(): status_output = gr.Textbox(label="Status", lines=10) + progress = gr.Progress() def update_provider_options(provider): if provider == "huggingface": @@ -107,23 +143,54 @@ def create_embeddings_tab(): else: return gr.update(visible=False) - 
embedding_provider.change( - fn=update_provider_options, - inputs=[embedding_provider], - outputs=[huggingface_model, openai_model, custom_embedding_model, embedding_api_url] - ) - - huggingface_model.change( - fn=update_huggingface_options, - inputs=[huggingface_model], - outputs=[custom_embedding_model] - ) + def update_database_path(database_type): + if database_type == "Media DB": + return media_db_path + elif database_type == "RAG Chat": + return rag_qa_db_path + else: # Character Chat + return character_chat_db_path - def create_all_embeddings(provider, hf_model, openai_model, custom_model, api_url, method, max_size, overlap, adaptive): + def create_all_embeddings(provider, hf_model, openai_model, custom_model, api_url, method, + max_size, overlap, adaptive, database_type, progress=gr.Progress()): try: - all_content = get_all_content_from_database() + # Initialize content based on database selection + if database_type == "Media DB": + all_content = get_all_content_from_database() + content_type = "media" + elif database_type == "RAG Chat": + all_content = [] + page = 1 + while True: + conversations, total_pages, _ = get_all_conversations(page=page) + if not conversations: + break + all_content.extend([{ + 'id': conv['conversation_id'], + 'content': get_conversation_text(conv['conversation_id']), + 'title': conv['title'], + 'type': 'conversation' + } for conv in conversations]) + progress(page / total_pages, desc=f"Loading conversations... Page {page}/{total_pages}") + page += 1 + else: # Character Chat + all_content = [] + page = 1 + while True: + notes, total_pages, _ = get_all_notes(page=page) + if not notes: + break + all_content.extend([{ + 'id': note['id'], + 'content': f"{note['title']}\n\n{note['content']}", + 'conversation_id': note['conversation_id'], + 'type': 'note' + } for note in notes]) + progress(page / total_pages, desc=f"Loading notes... Page {page}/{total_pages}") + page += 1 + if not all_content: - return "No content found in the database." 
+ return "No content found in the selected database." chunk_options = { 'method': method, @@ -132,7 +199,7 @@ def create_embeddings_tab(): 'adaptive': adaptive } - collection_name = "all_content_embeddings" + collection_name = f"{database_type.lower().replace(' ', '_')}_embeddings" collection = chroma_client.get_or_create_collection(name=collection_name) # Determine the model to use @@ -141,55 +208,113 @@ def create_embeddings_tab(): elif provider == "openai": model = openai_model else: - model = custom_model + model = api_url + + total_items = len(all_content) + for idx, item in enumerate(all_content): + progress((idx + 1) / total_items, desc=f"Processing item {idx + 1} of {total_items}") - for item in all_content: - media_id = item['id'] + content_id = item['id'] text = item['content'] chunks = improved_chunking_process(text, chunk_options) - for i, chunk in enumerate(chunks): + for chunk_idx, chunk in enumerate(chunks): chunk_text = chunk['text'] - chunk_id = f"doc_{media_id}_chunk_{i}" - - existing = collection.get(ids=[chunk_id]) - if existing['ids']: + chunk_id = f"{database_type.lower()}_{content_id}_chunk_{chunk_idx}" + + try: + embedding = create_embedding(chunk_text, provider, model, api_url) + metadata = { + 'content_id': str(content_id), + 'chunk_index': int(chunk_idx), + 'total_chunks': int(len(chunks)), + 'chunking_method': method, + 'max_chunk_size': int(max_size), + 'chunk_overlap': int(overlap), + 'adaptive_chunking': bool(adaptive), + 'embedding_model': model, + 'embedding_provider': provider, + 'content_type': item.get('type', 'media'), + 'conversation_id': item.get('conversation_id'), + **{k: (int(v) if isinstance(v, str) and v.isdigit() else v) + for k, v in chunk['metadata'].items()} + } + store_in_chroma(collection_name, [chunk_text], [embedding], [chunk_id], [metadata]) + + except Exception as e: + logging.error(f"Error processing chunk {chunk_id}: {str(e)}") continue - embedding = create_embedding(chunk_text, provider, model, api_url) - 
metadata = { - "media_id": str(media_id), - "chunk_index": i, - "total_chunks": len(chunks), - "chunking_method": method, - "max_chunk_size": max_size, - "chunk_overlap": overlap, - "adaptive_chunking": adaptive, - "embedding_model": model, - "embedding_provider": provider, - **chunk['metadata'] - } - store_in_chroma(collection_name, [chunk_text], [embedding], [chunk_id], [metadata]) - - return "Embeddings created and stored successfully for all content." + return f"Embeddings created and stored successfully for all {database_type} content." except Exception as e: logging.error(f"Error during embedding creation: {str(e)}") return f"Error: {str(e)}" + # Event handlers + embedding_provider.change( + fn=update_provider_options, + inputs=[embedding_provider], + outputs=[huggingface_model, openai_model, custom_embedding_model, embedding_api_url] + ) + + huggingface_model.change( + fn=update_huggingface_options, + inputs=[huggingface_model], + outputs=[custom_embedding_model] + ) + + database_selection.change( + fn=update_database_path, + inputs=[database_selection], + outputs=[current_db_path] + ) + create_button.click( fn=create_all_embeddings, - inputs=[embedding_provider, huggingface_model, openai_model, custom_embedding_model, embedding_api_url, - chunking_method, max_chunk_size, chunk_overlap, adaptive_chunking], + inputs=[ + embedding_provider, huggingface_model, openai_model, custom_embedding_model, + embedding_api_url, chunking_method, max_chunk_size, chunk_overlap, + adaptive_chunking, database_selection + ], outputs=status_output ) def create_view_embeddings_tab(): + # Load configuration first + config = load_and_log_configs() + if not config: + raise ValueError("Could not load configuration") + + # Get database paths from config + db_config = config['db_config'] + media_db_path = db_config['sqlite_path'] + rag_qa_db_path = os.path.join(os.path.dirname(media_db_path), "rag_chat.db") + character_chat_db_path = os.path.join(os.path.dirname(media_db_path), 
"character_chat.db") + chroma_db_path = db_config['chroma_db_path'] + with gr.TabItem("View/Update Embeddings", visible=True): gr.Markdown("# View and Update Embeddings") - item_mapping = gr.State({}) + # Initialize item_mapping as a Gradio State + + with gr.Row(): with gr.Column(): + # Add database selection + database_selection = gr.Radio( + choices=["Media DB", "RAG Chat", "Character Chat"], + label="Select Content Source", + value="Media DB", + info="Choose which database to view embeddings from" + ) + + # Add database path display + current_db_path = gr.Textbox( + label="Current Database Path", + value=media_db_path, + interactive=False + ) + item_dropdown = gr.Dropdown(label="Select Item", choices=[], interactive=True) refresh_button = gr.Button("Refresh Item List") embedding_status = gr.Textbox(label="Embedding Status", interactive=False) @@ -236,9 +361,10 @@ def create_view_embeddings_tab(): embedding_api_url = gr.Textbox( label="API URL (for local provider)", - value="http://localhost:8080/embedding", + value=config['embedding_config']['embedding_api_url'], visible=False ) + chunking_method = gr.Dropdown( choices=["words", "sentences", "paragraphs", "tokens", "semantic"], label="Chunking Method", @@ -267,15 +393,45 @@ def create_view_embeddings_tab(): ) contextual_api_key = gr.Textbox(label="API Key", lines=1) - def get_items_with_embedding_status(): + item_mapping = gr.State(value={}) + + def update_database_path(database_type): + if database_type == "Media DB": + return media_db_path + elif database_type == "RAG Chat": + return rag_qa_db_path + else: # Character Chat + return character_chat_db_path + + def get_items_with_embedding_status(database_type): try: - items = get_all_content_from_database() - collection = chroma_client.get_or_create_collection(name="all_content_embeddings") + # Get items based on database selection + if database_type == "Media DB": + items = get_all_content_from_database() + elif database_type == "RAG Chat": + conversations, _, 
_ = get_all_conversations(page=1) + items = [{ + 'id': conv['conversation_id'], + 'title': conv['title'], + 'type': 'conversation' + } for conv in conversations] + else: # Character Chat + notes, _, _ = get_all_notes(page=1) + items = [{ + 'id': note['id'], + 'title': note['title'], + 'type': 'note' + } for note in notes] + + collection_name = f"{database_type.lower().replace(' ', '_')}_embeddings" + collection = chroma_client.get_or_create_collection(name=collection_name) + choices = [] new_item_mapping = {} for item in items: try: - result = collection.get(ids=[f"doc_{item['id']}_chunk_0"]) + chunk_id = f"{database_type.lower()}_{item['id']}_chunk_0" + result = collection.get(ids=[chunk_id]) embedding_exists = result is not None and result.get('ids') and len(result['ids']) > 0 status = "Embedding exists" if embedding_exists else "No embedding" except Exception as e: @@ -303,40 +459,62 @@ def create_view_embeddings_tab(): else: return gr.update(visible=False) - def check_embedding_status(selected_item, item_mapping): + def check_embedding_status(selected_item, database_type, item_mapping): if not selected_item: return "Please select an item", "", "" + if item_mapping is None: + # If mapping is None, try to refresh it + try: + _, item_mapping = get_items_with_embedding_status(database_type) + except Exception as e: + return f"Error initializing item mapping: {str(e)}", "", "" + try: item_id = item_mapping.get(selected_item) if item_id is None: return f"Invalid item selected: {selected_item}", "", "" item_title = selected_item.rsplit(' (', 1)[0] - collection = chroma_client.get_or_create_collection(name="all_content_embeddings") + collection_name = f"{database_type.lower().replace(' ', '_')}_embeddings" + collection = chroma_client.get_or_create_collection(name=collection_name) + chunk_id = f"{database_type.lower()}_{item_id}_chunk_0" + + try: + result = collection.get(ids=[chunk_id], include=["embeddings", "metadatas"]) + except Exception as e: + 
logging.error(f"ChromaDB get error: {str(e)}") + return f"Error retrieving embedding for '{item_title}': {str(e)}", "", "" - result = collection.get(ids=[f"doc_{item_id}_chunk_0"], include=["embeddings", "metadatas"]) - logging.info(f"ChromaDB result for item '{item_title}' (ID: {item_id}): {result}") + # Check if result exists and has the expected structure + if not result or not isinstance(result, dict): + return f"No embedding found for item '{item_title}' (ID: {item_id})", "", "" - if not result['ids']: + # Check if we have any results + if not result.get('ids') or len(result['ids']) == 0: return f"No embedding found for item '{item_title}' (ID: {item_id})", "", "" - if not result['embeddings'] or not result['embeddings'][0]: + # Check if embeddings exist + if not result.get('embeddings') or not result['embeddings'][0]: return f"Embedding data missing for item '{item_title}' (ID: {item_id})", "", "" embedding = result['embeddings'][0] - metadata = result['metadatas'][0] if result['metadatas'] else {} + metadata = result.get('metadatas', [{}])[0] if result.get('metadatas') else {} embedding_preview = str(embedding[:50]) status = f"Embedding exists for item '{item_title}' (ID: {item_id})" return status, f"First 50 elements of embedding:\n{embedding_preview}", json.dumps(metadata, indent=2) except Exception as e: - logging.error(f"Error in check_embedding_status: {str(e)}") + logging.error(f"Error in check_embedding_status: {str(e)}", exc_info=True) return f"Error processing item: {selected_item}. 
Details: {str(e)}", "", "" - def create_new_embedding_for_item(selected_item, provider, hf_model, openai_model, custom_model, api_url, - method, max_size, overlap, adaptive, - item_mapping, use_contextual, contextual_api_choice=None): + def refresh_and_update(database_type): + choices_update, new_mapping = get_items_with_embedding_status(database_type) + return choices_update, new_mapping + + def create_new_embedding_for_item(selected_item, database_type, provider, hf_model, openai_model, + custom_model, api_url, method, max_size, overlap, adaptive, + item_mapping, use_contextual, contextual_api_choice=None): if not selected_item: return "Please select an item", "", "" @@ -345,8 +523,26 @@ def create_view_embeddings_tab(): if item_id is None: return f"Invalid item selected: {selected_item}", "", "" - items = get_all_content_from_database() - item = next((item for item in items if item['id'] == item_id), None) + # Get item content based on database type + if database_type == "Media DB": + items = get_all_content_from_database() + item = next((item for item in items if item['id'] == item_id), None) + elif database_type == "RAG Chat": + item = { + 'id': item_id, + 'content': get_conversation_text(item_id), + 'title': selected_item.rsplit(' (', 1)[0], + 'type': 'conversation' + } + else: # Character Chat + note = get_note_by_id(item_id) + item = { + 'id': item_id, + 'content': f"{note['title']}\n\n{note['content']}", + 'title': note['title'], + 'type': 'note' + } + if not item: return f"Item not found: {item_id}", "", "" @@ -359,11 +555,11 @@ def create_view_embeddings_tab(): logging.info(f"Chunking content for item: {item['title']} (ID: {item_id})") chunks = chunk_for_embedding(item['content'], item['title'], chunk_options) - collection_name = "all_content_embeddings" + collection_name = f"{database_type.lower().replace(' ', '_')}_embeddings" collection = chroma_client.get_or_create_collection(name=collection_name) # Delete existing embeddings for this item - 
existing_ids = [f"doc_{item_id}_chunk_{i}" for i in range(len(chunks))] + existing_ids = [f"{database_type.lower()}_{item_id}_chunk_{i}" for i in range(len(chunks))] collection.delete(ids=existing_ids) logging.info(f"Deleted {len(existing_ids)} existing embeddings for item {item_id}") @@ -381,7 +577,7 @@ def create_view_embeddings_tab(): contextualized_text = chunk_text context = None - chunk_id = f"doc_{item_id}_chunk_{i}" + chunk_id = f"{database_type.lower()}_{item_id}_chunk_{i}" # Determine the model to use if provider == "huggingface": @@ -392,7 +588,7 @@ def create_view_embeddings_tab(): model = custom_model metadata = { - "media_id": str(item_id), + "content_id": str(item_id), "chunk_index": i, "total_chunks": len(chunks), "chunking_method": method, @@ -441,15 +637,25 @@ def create_view_embeddings_tab(): logging.error(f"Error in create_new_embedding_for_item: {str(e)}", exc_info=True) return f"Error creating embedding: {str(e)}", "", "" + # Wire up all the event handlers + database_selection.change( + update_database_path, + inputs=[database_selection], + outputs=[current_db_path] + ) + refresh_button.click( get_items_with_embedding_status, + inputs=[database_selection], outputs=[item_dropdown, item_mapping] ) + item_dropdown.change( check_embedding_status, - inputs=[item_dropdown, item_mapping], + inputs=[item_dropdown, database_selection, item_mapping], outputs=[embedding_status, embedding_preview, embedding_metadata] ) + create_new_embedding_button.click( create_new_embedding_for_item, inputs=[item_dropdown, embedding_provider, huggingface_model, openai_model, custom_embedding_model, embedding_api_url, @@ -469,9 +675,10 @@ def create_view_embeddings_tab(): ) return (item_dropdown, refresh_button, embedding_status, embedding_preview, embedding_metadata, - create_new_embedding_button, embedding_provider, huggingface_model, openai_model, custom_embedding_model, embedding_api_url, - chunking_method, max_chunk_size, chunk_overlap, adaptive_chunking, - 
use_contextual_embeddings, contextual_api_choice, contextual_api_key) + create_new_embedding_button, embedding_provider, huggingface_model, openai_model, + custom_embedding_model, embedding_api_url, chunking_method, max_chunk_size, + chunk_overlap, adaptive_chunking, use_contextual_embeddings, + contextual_api_choice, contextual_api_key) def create_purge_embeddings_tab(): diff --git a/App_Function_Libraries/Gradio_UI/Evaluations_Benchmarks_tab.py b/App_Function_Libraries/Gradio_UI/Evaluations_Benchmarks_tab.py index f1ffbc69ecbcf8786493397cf6ed45931e561f13..b9529a3002b4d536c4d8ae7e53c24bd2130bc5a2 100644 --- a/App_Function_Libraries/Gradio_UI/Evaluations_Benchmarks_tab.py +++ b/App_Function_Libraries/Gradio_UI/Evaluations_Benchmarks_tab.py @@ -1,9 +1,12 @@ ################################################################################################### # Evaluations_Benchmarks_tab.py - Gradio code for G-Eval testing # We will use the G-Eval API to evaluate the quality of the generated summaries. 
+import logging import gradio as gr from App_Function_Libraries.Benchmarks_Evaluations.ms_g_eval import run_geval +from App_Function_Libraries.Utils.Utils import default_api_endpoint, global_api_endpoints, format_api_name + def create_geval_tab(): with gr.Tab("G-Eval", visible=True): @@ -31,13 +34,25 @@ def create_geval_tab(): def create_infinite_bench_tab(): + try: + default_value = None + if default_api_endpoint: + if default_api_endpoint in global_api_endpoints: + default_value = format_api_name(default_api_endpoint) + else: + logging.warning(f"Default API endpoint '{default_api_endpoint}' not found in global_api_endpoints") + except Exception as e: + logging.error(f"Error setting default API endpoint: {str(e)}") + default_value = None with gr.Tab("Infinite Bench", visible=True): gr.Markdown("# Infinite Bench Evaluation (Coming Soon)") with gr.Row(): with gr.Column(): + # Refactored API selection dropdown api_name_input = gr.Dropdown( - choices=["OpenAI", "Anthropic", "Cohere", "Groq", "OpenRouter", "DeepSeek", "HuggingFace", "Mistral", "Llama.cpp", "Kobold", "Ooba", "Tabbyapi", "VLLM", "Local-LLM", "Ollama"], - label="Select API" + choices=["None"] + [format_api_name(api) for api in global_api_endpoints], + value=default_value, + label="API for Summarization (Optional)" ) api_key_input = gr.Textbox(label="API Key (if required)", type="password") evaluate_button = gr.Button("Evaluate Summary") diff --git a/App_Function_Libraries/Gradio_UI/Explain_summarize_tab.py b/App_Function_Libraries/Gradio_UI/Explain_summarize_tab.py index 37349d8df88886a7f67f47fbbbb175cb76893698..f75ec04dc063dceafa4dbb22b511b71fc1745f05 100644 --- a/App_Function_Libraries/Gradio_UI/Explain_summarize_tab.py +++ b/App_Function_Libraries/Gradio_UI/Explain_summarize_tab.py @@ -7,7 +7,7 @@ import logging # External Imports import gradio as gr -from App_Function_Libraries.DB.DB_Manager import load_preset_prompts +from App_Function_Libraries.DB.DB_Manager import list_prompts from 
App_Function_Libraries.Gradio_UI.Gradio_Shared import update_user_prompt # # Local Imports @@ -17,6 +17,9 @@ from App_Function_Libraries.Summarization.Local_Summarization_Lib import summari from App_Function_Libraries.Summarization.Summarization_General_Lib import summarize_with_openai, summarize_with_anthropic, \ summarize_with_cohere, summarize_with_groq, summarize_with_openrouter, summarize_with_deepseek, \ summarize_with_huggingface +from App_Function_Libraries.Utils.Utils import default_api_endpoint, global_api_endpoints, format_api_name + + # # ############################################################################################################ @@ -24,32 +27,62 @@ from App_Function_Libraries.Summarization.Summarization_General_Lib import summa # Functions: def create_summarize_explain_tab(): + try: + default_value = None + if default_api_endpoint: + if default_api_endpoint in global_api_endpoints: + default_value = format_api_name(default_api_endpoint) + else: + logging.warning(f"Default API endpoint '{default_api_endpoint}' not found in global_api_endpoints") + except Exception as e: + logging.error(f"Error setting default API endpoint: {str(e)}") + default_value = None + with gr.TabItem("Analyze Text", visible=True): gr.Markdown("# Analyze / Explain / Summarize Text without ingesting it into the DB") + + # Initialize state variables for pagination + current_page_state = gr.State(value=1) + total_pages_state = gr.State(value=1) + with gr.Row(): with gr.Column(): with gr.Row(): - text_to_work_input = gr.Textbox(label="Text to be Explained or Summarized", - placeholder="Enter the text you want explained or summarized here", - lines=20) + text_to_work_input = gr.Textbox( + label="Text to be Explained or Summarized", + placeholder="Enter the text you want explained or summarized here", + lines=20 + ) with gr.Row(): explanation_checkbox = gr.Checkbox(label="Explain Text", value=True) summarization_checkbox = gr.Checkbox(label="Summarize Text", value=True) 
- custom_prompt_checkbox = gr.Checkbox(label="Use a Custom Prompt", - value=False, - visible=True) - preset_prompt_checkbox = gr.Checkbox(label="Use a pre-set Prompt", - value=False, - visible=True) + custom_prompt_checkbox = gr.Checkbox( + label="Use a Custom Prompt", + value=False, + visible=True + ) + preset_prompt_checkbox = gr.Checkbox( + label="Use a pre-set Prompt", + value=False, + visible=True + ) with gr.Row(): - preset_prompt = gr.Dropdown(label="Select Preset Prompt", - choices=load_preset_prompts(), - visible=False) + # Add pagination controls + preset_prompt = gr.Dropdown( + label="Select Preset Prompt", + choices=[], + visible=False + ) + prev_page_button = gr.Button("Previous Page", visible=False) + page_display = gr.Markdown("Page 1 of X", visible=False) + next_page_button = gr.Button("Next Page", visible=False) with gr.Row(): - custom_prompt_input = gr.Textbox(label="Custom Prompt", - placeholder="Enter custom prompt here", - lines=3, - visible=False) + custom_prompt_input = gr.Textbox( + label="Custom Prompt", + placeholder="Enter custom prompt here", + lines=10, + visible=False + ) with gr.Row(): system_prompt_input = gr.Textbox(label="System Prompt", value="""You are a bulleted notes specialist. [INST]```When creating comprehensive bulleted notes, you should follow these guidelines: Use multiple headings based on the referenced topics, not categories like quotes or terms. Headings should be surrounded by bold formatting and not be listed as bullet points themselves. Leave no space between headings and their corresponding list items underneath. Important terms within the content should be emphasized by setting them in bold font. Any text that ends with a colon should also be bolded. Before submitting your response, review the instructions, and make any corrections necessary to adhered to the specified format. 
Do not reference these instructions within the notes.``` \nBased on the content between backticks create comprehensive bulleted notes.[/INST] @@ -69,19 +102,21 @@ def create_summarize_explain_tab(): - Ensure adherence to specified format - Do not reference these instructions in your response.[INST] {{ .Prompt }} [/INST] """, - lines=3, + lines=10, visible=False, interactive=True) + # Refactored API selection dropdown api_endpoint = gr.Dropdown( - choices=[None, "Local-LLM", "OpenAI", "Anthropic", "Cohere", "Groq", "DeepSeek", "Mistral", - "OpenRouter", - "Llama.cpp", "Kobold", "Ooba", "Tabbyapi", "VLLM", "ollama", "HuggingFace", "Custom-OpenAI-API"], - value=None, - label="API to be used for request (Mandatory)" + choices=["None"] + [format_api_name(api) for api in global_api_endpoints], + value=default_value, + label="API for Summarization/Analysis (Optional)" ) with gr.Row(): - api_key_input = gr.Textbox(label="API Key (if required)", placeholder="Enter your API key here", - type="password") + api_key_input = gr.Textbox( + label="API Key (if required)", + placeholder="Enter your API key here", + type="password" + ) with gr.Row(): explain_summarize_button = gr.Button("Explain/Summarize") @@ -90,17 +125,83 @@ def create_summarize_explain_tab(): explanation_output = gr.Textbox(label="Explanation:", lines=20) custom_prompt_output = gr.Textbox(label="Custom Prompt:", lines=20, visible=True) + # Handle custom prompt checkbox change custom_prompt_checkbox.change( fn=lambda x: (gr.update(visible=x), gr.update(visible=x)), inputs=[custom_prompt_checkbox], outputs=[custom_prompt_input, system_prompt_input] ) + + # Handle preset prompt checkbox change + def on_preset_prompt_checkbox_change(is_checked): + if is_checked: + prompts, total_pages, current_page = list_prompts(page=1, per_page=20) + page_display_text = f"Page {current_page} of {total_pages}" + return ( + gr.update(visible=True, interactive=True, choices=prompts), # preset_prompt + gr.update(visible=True), # 
prev_page_button + gr.update(visible=True), # next_page_button + gr.update(value=page_display_text, visible=True), # page_display + current_page, # current_page_state + total_pages # total_pages_state + ) + else: + return ( + gr.update(visible=False, interactive=False), # preset_prompt + gr.update(visible=False), # prev_page_button + gr.update(visible=False), # next_page_button + gr.update(visible=False), # page_display + 1, # current_page_state + 1 # total_pages_state + ) + preset_prompt_checkbox.change( - fn=lambda x: gr.update(visible=x), + fn=on_preset_prompt_checkbox_change, inputs=[preset_prompt_checkbox], - outputs=[preset_prompt] + outputs=[ + preset_prompt, + prev_page_button, + next_page_button, + page_display, + current_page_state, + total_pages_state + ] + ) + + # Pagination button functions + def on_prev_page_click(current_page, total_pages): + new_page = max(current_page - 1, 1) + prompts, total_pages, current_page = list_prompts(page=new_page, per_page=20) + page_display_text = f"Page {current_page} of {total_pages}" + return ( + gr.update(choices=prompts), + gr.update(value=page_display_text), + current_page + ) + + prev_page_button.click( + fn=on_prev_page_click, + inputs=[current_page_state, total_pages_state], + outputs=[preset_prompt, page_display, current_page_state] ) + def on_next_page_click(current_page, total_pages): + new_page = min(current_page + 1, total_pages) + prompts, total_pages, current_page = list_prompts(page=new_page, per_page=20) + page_display_text = f"Page {current_page} of {total_pages}" + return ( + gr.update(choices=prompts), + gr.update(value=page_display_text), + current_page + ) + + next_page_button.click( + fn=on_next_page_click, + inputs=[current_page_state, total_pages_state], + outputs=[preset_prompt, page_display, current_page_state] + ) + + # Update prompts when a preset is selected def update_prompts(preset_name): prompts = update_user_prompt(preset_name) return ( @@ -109,18 +210,27 @@ def 
create_summarize_explain_tab(): ) preset_prompt.change( - update_prompts, - inputs=preset_prompt, + fn=update_prompts, + inputs=[preset_prompt], outputs=[custom_prompt_input, system_prompt_input] ) explain_summarize_button.click( fn=summarize_explain_text, - inputs=[text_to_work_input, api_endpoint, api_key_input, summarization_checkbox, explanation_checkbox, custom_prompt_input, system_prompt_input], + inputs=[ + text_to_work_input, + api_endpoint, + api_key_input, + summarization_checkbox, + explanation_checkbox, + custom_prompt_input, + system_prompt_input + ], outputs=[summarization_output, explanation_output, custom_prompt_output] ) + def summarize_explain_text(message, api_endpoint, api_key, summarization, explanation, custom_prompt, custom_system_prompt,): global custom_prompt_output summarization_response = None diff --git a/App_Function_Libraries/Gradio_UI/Export_Functionality.py b/App_Function_Libraries/Gradio_UI/Export_Functionality.py index 2feed8605a614624f6b6246e6379dd7582e15240..dba3e88556b5ec3746d5f06930892d80b056a37c 100644 --- a/App_Function_Libraries/Gradio_UI/Export_Functionality.py +++ b/App_Function_Libraries/Gradio_UI/Export_Functionality.py @@ -6,9 +6,11 @@ import math import logging import shutil import tempfile -from typing import List, Dict, Optional, Tuple +from typing import List, Dict, Optional, Tuple, Any import gradio as gr -from App_Function_Libraries.DB.DB_Manager import DatabaseError +from App_Function_Libraries.DB.DB_Manager import DatabaseError, fetch_all_notes, fetch_all_conversations, \ + get_keywords_for_note, fetch_notes_by_ids, fetch_conversations_by_ids +from App_Function_Libraries.DB.RAG_QA_Chat_DB import get_keywords_for_conversation from App_Function_Libraries.Gradio_UI.Gradio_Shared import fetch_item_details, fetch_items_by_keyword, browse_items logger = logging.getLogger(__name__) @@ -36,7 +38,7 @@ def export_items_by_keyword(keyword: str) -> str: items = fetch_items_by_keyword(keyword) if not items: 
logger.warning(f"No items found for keyword: {keyword}") - return None + return f"No items found for keyword: {keyword}" # Create a temporary directory to store individual markdown files with tempfile.TemporaryDirectory() as temp_dir: @@ -66,7 +68,7 @@ def export_items_by_keyword(keyword: str) -> str: return final_zip_path except Exception as e: logger.error(f"Error exporting items for keyword '{keyword}': {str(e)}") - return None + return f"Error exporting items for keyword '{keyword}': {str(e)}" def export_selected_items(selected_items: List[Dict]) -> Tuple[Optional[str], str]: @@ -146,121 +148,747 @@ def display_search_results_export_tab(search_query: str, search_type: str, page: logger.error(error_message) return [], error_message, 1, 1 +# +# End of Media DB Export functionality +################################################################ -def create_export_tab(): - with gr.Tab("Search and Export"): - with gr.Row(): - with gr.Column(): - gr.Markdown("# Search and Export Items") - gr.Markdown("Search for items and export them as markdown files") - gr.Markdown("You can also export items by keyword") - search_query = gr.Textbox(label="Search Query") - search_type = gr.Radio(["Title", "URL", "Keyword", "Content"], label="Search By") - search_button = gr.Button("Search") - - with gr.Column(): - prev_button = gr.Button("Previous Page") - next_button = gr.Button("Next Page") - - current_page = gr.State(1) - total_pages = gr.State(1) - - search_results = gr.CheckboxGroup(label="Search Results", choices=[]) - export_selected_button = gr.Button("Export Selected Items") - - keyword_input = gr.Textbox(label="Enter keyword for export") - export_by_keyword_button = gr.Button("Export items by keyword") - - export_output = gr.File(label="Download Exported File") - error_output = gr.Textbox(label="Status/Error Messages", interactive=False) - - def search_and_update(query, search_type, page): - results, message, current, total = display_search_results_export_tab(query, 
search_type, page) - logger.debug(f"search_and_update results: {results}") - return results, message, current, total, gr.update(choices=results) - - search_button.click( - fn=search_and_update, - inputs=[search_query, search_type, current_page], - outputs=[search_results, error_output, current_page, total_pages, search_results], - show_progress="full" - ) - - - def update_page(current, total, direction): - new_page = max(1, min(total, current + direction)) - return new_page - - prev_button.click( - fn=update_page, - inputs=[current_page, total_pages, gr.State(-1)], - outputs=[current_page] - ).then( - fn=search_and_update, - inputs=[search_query, search_type, current_page], - outputs=[search_results, error_output, current_page, total_pages], - show_progress=True - ) - - next_button.click( - fn=update_page, - inputs=[current_page, total_pages, gr.State(1)], - outputs=[current_page] - ).then( - fn=search_and_update, - inputs=[search_query, search_type, current_page], - outputs=[search_results, error_output, current_page, total_pages], - show_progress=True - ) - - def handle_export_selected(selected_items): - logger.debug(f"Exporting selected items: {selected_items}") - return export_selected_items(selected_items) - - export_selected_button.click( - fn=handle_export_selected, - inputs=[search_results], - outputs=[export_output, error_output], - show_progress="full" - ) - - export_by_keyword_button.click( - fn=export_items_by_keyword, - inputs=[keyword_input], - outputs=[export_output, error_output], - show_progress="full" - ) - - def handle_item_selection(selected_items): - logger.debug(f"Selected items: {selected_items}") - if not selected_items: - return None, "No item selected" - - try: - # Assuming selected_items is a list of dictionaries - selected_item = selected_items[0] - logger.debug(f"First selected item: {selected_item}") - - # Check if 'value' is a string (JSON) or already a dictionary - if isinstance(selected_item['value'], str): - item_data = 
json.loads(selected_item['value']) - else: - item_data = selected_item['value'] - - logger.debug(f"Item data: {item_data}") - - item_id = item_data['id'] - return export_item_as_markdown(item_id) - except Exception as e: - error_message = f"Error processing selected item: {str(e)}" - logger.error(error_message) - return None, error_message - - search_results.select( - fn=handle_item_selection, - inputs=[search_results], - outputs=[export_output, error_output], - show_progress="full" - ) +################################################################ +# +# Functions for RAG Chat DB Export functionality + + +def export_rag_conversations_as_json( + selected_conversations: Optional[List[Dict[str, Any]]] = None +) -> Tuple[Optional[str], str]: + """ + Export conversations to a JSON file. + + Args: + selected_conversations: Optional list of conversation dictionaries + + Returns: + Tuple of (filename or None, status message) + """ + try: + if selected_conversations: + # Extract conversation IDs from selected items + conversation_ids = [] + for item in selected_conversations: + if isinstance(item, str): + item_data = json.loads(item) + elif isinstance(item, dict) and 'value' in item: + item_data = item['value'] if isinstance(item['value'], dict) else json.loads(item['value']) + else: + item_data = item + conversation_ids.append(item_data['conversation_id']) + + conversations = fetch_conversations_by_ids(conversation_ids) + else: + conversations = fetch_all_conversations() + + export_data = [] + for conversation_id, title, messages in conversations: + # Get keywords for the conversation + keywords = get_keywords_for_conversation(conversation_id) + + conversation_data = { + "conversation_id": conversation_id, + "title": title, + "keywords": keywords, + "messages": [ + {"role": role, "content": content} + for role, content in messages + ] + } + export_data.append(conversation_data) + + filename = "rag_conversations_export.json" + with open(filename, "w", encoding='utf-8') 
as f: + json.dump(export_data, f, indent=2, ensure_ascii=False) + + logger.info(f"Successfully exported {len(export_data)} conversations to {filename}") + return filename, f"Successfully exported {len(export_data)} conversations to {filename}" + except Exception as e: + error_message = f"Error exporting conversations: {str(e)}" + logger.error(error_message) + return None, error_message + + +def export_rag_notes_as_json( + selected_notes: Optional[List[Dict[str, Any]]] = None +) -> Tuple[Optional[str], str]: + """ + Export notes to a JSON file. + + Args: + selected_notes: Optional list of note dictionaries + + Returns: + Tuple of (filename or None, status message) + """ + try: + if selected_notes: + # Extract note IDs from selected items + note_ids = [] + for item in selected_notes: + if isinstance(item, str): + item_data = json.loads(item) + elif isinstance(item, dict) and 'value' in item: + item_data = item['value'] if isinstance(item['value'], dict) else json.loads(item['value']) + else: + item_data = item + note_ids.append(item_data['id']) + + notes = fetch_notes_by_ids(note_ids) + else: + notes = fetch_all_notes() + + export_data = [] + for note_id, title, content in notes: + # Get keywords for the note + keywords = get_keywords_for_note(note_id) + + note_data = { + "note_id": note_id, + "title": title, + "content": content, + "keywords": keywords + } + export_data.append(note_data) + + filename = "rag_notes_export.json" + with open(filename, "w", encoding='utf-8') as f: + json.dump(export_data, f, indent=2, ensure_ascii=False) + + logger.info(f"Successfully exported {len(export_data)} notes to {filename}") + return filename, f"Successfully exported {len(export_data)} notes to {filename}" + except Exception as e: + error_message = f"Error exporting notes: {str(e)}" + logger.error(error_message) + return None, error_message + + +def display_rag_conversations(search_query: str = "", page: int = 1, items_per_page: int = 10): + """Display conversations for 
selection in the export tab.""" + try: + conversations = fetch_all_conversations() + + if search_query: + # Simple search implementation - can be enhanced based on needs + conversations = [ + conv for conv in conversations + if search_query.lower() in conv[1].lower() # Search in title + ] + + # Implement pagination + start_idx = (page - 1) * items_per_page + end_idx = start_idx + items_per_page + paginated_conversations = conversations[start_idx:end_idx] + total_pages = (len(conversations) + items_per_page - 1) // items_per_page + + # Format for checkbox group + checkbox_data = [ + { + "name": f"Title: {title}\nMessages: {len(messages)}", + "value": {"conversation_id": conv_id, "title": title} + } + for conv_id, title, messages in paginated_conversations + ] + + return ( + checkbox_data, + f"Found {len(conversations)} conversations (showing page {page} of {total_pages})", + page, + total_pages + ) + except Exception as e: + error_message = f"Error displaying conversations: {str(e)}" + logger.error(error_message) + return [], error_message, 1, 1 + + +def display_rag_notes(search_query: str = "", page: int = 1, items_per_page: int = 10): + """Display notes for selection in the export tab.""" + try: + notes = fetch_all_notes() + + if search_query: + # Simple search implementation - can be enhanced based on needs + notes = [ + note for note in notes + if search_query.lower() in note[1].lower() # Search in title + or search_query.lower() in note[2].lower() # Search in content + ] + + # Implement pagination + start_idx = (page - 1) * items_per_page + end_idx = start_idx + items_per_page + paginated_notes = notes[start_idx:end_idx] + total_pages = (len(notes) + items_per_page - 1) // items_per_page + + # Format for checkbox group + checkbox_data = [ + { + "name": f"Title: {title}\nContent preview: {content[:100]}...", + "value": {"id": note_id, "title": title} + } + for note_id, title, content in paginated_notes + ] + + return ( + checkbox_data, + f"Found {len(notes)} 
notes (showing page {page} of {total_pages})", + page, + total_pages + ) + except Exception as e: + error_message = f"Error displaying notes: {str(e)}" + logger.error(error_message) + return [], error_message, 1, 1 + + +def create_rag_export_tab(): + """Create the RAG QA Chat export tab interface.""" + with gr.Tab("RAG QA Chat Export"): + with gr.Tabs(): + # Conversations Export Tab + with gr.Tab("Export Conversations"): + with gr.Row(): + with gr.Column(): + gr.Markdown("## Export RAG QA Chat Conversations") + conversation_search = gr.Textbox(label="Search Conversations") + conversation_search_button = gr.Button("Search") + + with gr.Column(): + conversation_prev_button = gr.Button("Previous Page") + conversation_next_button = gr.Button("Next Page") + + conversation_current_page = gr.State(1) + conversation_total_pages = gr.State(1) + + conversation_results = gr.CheckboxGroup(label="Select Conversations to Export") + export_selected_conversations_button = gr.Button("Export Selected Conversations") + export_all_conversations_button = gr.Button("Export All Conversations") + + conversation_export_output = gr.File(label="Download Exported Conversations") + conversation_status = gr.Textbox(label="Status", interactive=False) + + # Notes Export Tab + with gr.Tab("Export Notes"): + with gr.Row(): + with gr.Column(): + gr.Markdown("## Export RAG QA Chat Notes") + notes_search = gr.Textbox(label="Search Notes") + notes_search_button = gr.Button("Search") + + with gr.Column(): + notes_prev_button = gr.Button("Previous Page") + notes_next_button = gr.Button("Next Page") + + notes_current_page = gr.State(1) + notes_total_pages = gr.State(1) + + notes_results = gr.CheckboxGroup(label="Select Notes to Export") + export_selected_notes_button = gr.Button("Export Selected Notes") + export_all_notes_button = gr.Button("Export All Notes") + + notes_export_output = gr.File(label="Download Exported Notes") + notes_status = gr.Textbox(label="Status", interactive=False) + + # Event 
handlers for conversations + def search_conversations(query, page): + return display_rag_conversations(query, page) + + conversation_search_button.click( + fn=search_conversations, + inputs=[conversation_search, conversation_current_page], + outputs=[conversation_results, conversation_status, conversation_current_page, conversation_total_pages] + ) + + def update_conversation_page(current, total, direction): + new_page = max(1, min(total, current + direction)) + return new_page + + conversation_prev_button.click( + fn=update_conversation_page, + inputs=[conversation_current_page, conversation_total_pages, gr.State(-1)], + outputs=[conversation_current_page] + ).then( + fn=search_conversations, + inputs=[conversation_search, conversation_current_page], + outputs=[conversation_results, conversation_status, conversation_current_page, conversation_total_pages] + ) + + conversation_next_button.click( + fn=update_conversation_page, + inputs=[conversation_current_page, conversation_total_pages, gr.State(1)], + outputs=[conversation_current_page] + ).then( + fn=search_conversations, + inputs=[conversation_search, conversation_current_page], + outputs=[conversation_results, conversation_status, conversation_current_page, conversation_total_pages] + ) + + export_selected_conversations_button.click( + fn=export_rag_conversations_as_json, + inputs=[conversation_results], + outputs=[conversation_export_output, conversation_status] + ) + + export_all_conversations_button.click( + fn=lambda: export_rag_conversations_as_json(), + outputs=[conversation_export_output, conversation_status] + ) + + # Event handlers for notes + def search_notes(query, page): + return display_rag_notes(query, page) + + notes_search_button.click( + fn=search_notes, + inputs=[notes_search, notes_current_page], + outputs=[notes_results, notes_status, notes_current_page, notes_total_pages] + ) + + def update_notes_page(current, total, direction): + new_page = max(1, min(total, current + direction)) + return 
new_page + + notes_prev_button.click( + fn=update_notes_page, + inputs=[notes_current_page, notes_total_pages, gr.State(-1)], + outputs=[notes_current_page] + ).then( + fn=search_notes, + inputs=[notes_search, notes_current_page], + outputs=[notes_results, notes_status, notes_current_page, notes_total_pages] + ) + + notes_next_button.click( + fn=update_notes_page, + inputs=[notes_current_page, notes_total_pages, gr.State(1)], + outputs=[notes_current_page] + ).then( + fn=search_notes, + inputs=[notes_search, notes_current_page], + outputs=[notes_results, notes_status, notes_current_page, notes_total_pages] + ) + + export_selected_notes_button.click( + fn=export_rag_notes_as_json, + inputs=[notes_results], + outputs=[notes_export_output, notes_status] + ) + + export_all_notes_button.click( + fn=lambda: export_rag_notes_as_json(), + outputs=[notes_export_output, notes_status] + ) + +# +# End of RAG Chat DB Export functionality +##################################################### + +def create_export_tabs(): + """Create the unified export interface with all export tabs.""" + with gr.Tabs(): + # Media DB Export Tab + with gr.Tab("Media DB Export"): + with gr.Row(): + with gr.Column(): + gr.Markdown("# Search and Export Items") + gr.Markdown("Search for items and export them as markdown files") + gr.Markdown("You can also export items by keyword") + search_query = gr.Textbox(label="Search Query") + search_type = gr.Radio(["Title", "URL", "Keyword", "Content"], label="Search By") + search_button = gr.Button("Search") + + with gr.Column(): + prev_button = gr.Button("Previous Page") + next_button = gr.Button("Next Page") + + current_page = gr.State(1) + total_pages = gr.State(1) + + search_results = gr.CheckboxGroup(label="Search Results", choices=[]) + export_selected_button = gr.Button("Export Selected Items") + + keyword_input = gr.Textbox(label="Enter keyword for export") + export_by_keyword_button = gr.Button("Export items by keyword") + + export_output = 
gr.File(label="Download Exported File") + error_output = gr.Textbox(label="Status/Error Messages", interactive=False) + + # Conversations Export Tab + with gr.Tab("RAG Conversations Export"): + with gr.Row(): + with gr.Column(): + gr.Markdown("## Export RAG QA Chat Conversations") + conversation_search = gr.Textbox(label="Search Conversations") + conversation_search_button = gr.Button("Search") + + with gr.Column(): + conversation_prev_button = gr.Button("Previous Page") + conversation_next_button = gr.Button("Next Page") + + conversation_current_page = gr.State(1) + conversation_total_pages = gr.State(1) + + conversation_results = gr.CheckboxGroup(label="Select Conversations to Export") + export_selected_conversations_button = gr.Button("Export Selected Conversations") + export_all_conversations_button = gr.Button("Export All Conversations") + + conversation_export_output = gr.File(label="Download Exported Conversations") + conversation_status = gr.Textbox(label="Status", interactive=False) + + # Notes Export Tab + with gr.Tab("RAG Notes Export"): + with gr.Row(): + with gr.Column(): + gr.Markdown("## Export RAG QA Chat Notes") + notes_search = gr.Textbox(label="Search Notes") + notes_search_button = gr.Button("Search") + + with gr.Column(): + notes_prev_button = gr.Button("Previous Page") + notes_next_button = gr.Button("Next Page") + + notes_current_page = gr.State(1) + notes_total_pages = gr.State(1) + + notes_results = gr.CheckboxGroup(label="Select Notes to Export") + export_selected_notes_button = gr.Button("Export Selected Notes") + export_all_notes_button = gr.Button("Export All Notes") + + notes_export_output = gr.File(label="Download Exported Notes") + notes_status = gr.Textbox(label="Status", interactive=False) + + # Event handlers for media DB + def search_and_update(query, search_type, page): + results, message, current, total = display_search_results_export_tab(query, search_type, page) + logger.debug(f"search_and_update results: {results}") + return 
results, message, current, total, gr.update(choices=results) + + def update_page(current, total, direction): + new_page = max(1, min(total, current + direction)) + return new_page + + def handle_export_selected(selected_items): + logger.debug(f"Exporting selected items: {selected_items}") + return export_selected_items(selected_items) + + def handle_item_selection(selected_items): + logger.debug(f"Selected items: {selected_items}") + if not selected_items: + return None, "No item selected" + + try: + selected_item = selected_items[0] + logger.debug(f"First selected item: {selected_item}") + + if isinstance(selected_item['value'], str): + item_data = json.loads(selected_item['value']) + else: + item_data = selected_item['value'] + + logger.debug(f"Item data: {item_data}") + item_id = item_data['id'] + return export_item_as_markdown(item_id) + except Exception as e: + error_message = f"Error processing selected item: {str(e)}" + logger.error(error_message) + return None, error_message + + search_button.click( + fn=search_and_update, + inputs=[search_query, search_type, current_page], + outputs=[search_results, error_output, current_page, total_pages, search_results], + show_progress="full" + ) + + prev_button.click( + fn=update_page, + inputs=[current_page, total_pages, gr.State(-1)], + outputs=[current_page] + ).then( + fn=search_and_update, + inputs=[search_query, search_type, current_page], + outputs=[search_results, error_output, current_page, total_pages], + show_progress=True + ) + + next_button.click( + fn=update_page, + inputs=[current_page, total_pages, gr.State(1)], + outputs=[current_page] + ).then( + fn=search_and_update, + inputs=[search_query, search_type, current_page], + outputs=[search_results, error_output, current_page, total_pages], + show_progress=True + ) + + export_selected_button.click( + fn=handle_export_selected, + inputs=[search_results], + outputs=[export_output, error_output], + show_progress="full" + ) + + export_by_keyword_button.click( 
+ fn=export_items_by_keyword, + inputs=[keyword_input], + outputs=[export_output, error_output], + show_progress="full" + ) + + search_results.select( + fn=handle_item_selection, + inputs=[search_results], + outputs=[export_output, error_output], + show_progress="full" + ) + + # Event handlers for conversations + def search_conversations(query, page): + return display_rag_conversations(query, page) + + def update_conversation_page(current, total, direction): + new_page = max(1, min(total, current + direction)) + return new_page + + conversation_search_button.click( + fn=search_conversations, + inputs=[conversation_search, conversation_current_page], + outputs=[conversation_results, conversation_status, conversation_current_page, conversation_total_pages] + ) + + conversation_prev_button.click( + fn=update_conversation_page, + inputs=[conversation_current_page, conversation_total_pages, gr.State(-1)], + outputs=[conversation_current_page] + ).then( + fn=search_conversations, + inputs=[conversation_search, conversation_current_page], + outputs=[conversation_results, conversation_status, conversation_current_page, conversation_total_pages] + ) + + conversation_next_button.click( + fn=update_conversation_page, + inputs=[conversation_current_page, conversation_total_pages, gr.State(1)], + outputs=[conversation_current_page] + ).then( + fn=search_conversations, + inputs=[conversation_search, conversation_current_page], + outputs=[conversation_results, conversation_status, conversation_current_page, conversation_total_pages] + ) + + export_selected_conversations_button.click( + fn=export_rag_conversations_as_json, + inputs=[conversation_results], + outputs=[conversation_export_output, conversation_status] + ) + + export_all_conversations_button.click( + fn=lambda: export_rag_conversations_as_json(), + outputs=[conversation_export_output, conversation_status] + ) + + # Event handlers for notes + def search_notes(query, page): + return display_rag_notes(query, page) + + def 
update_notes_page(current, total, direction): + new_page = max(1, min(total, current + direction)) + return new_page + + notes_search_button.click( + fn=search_notes, + inputs=[notes_search, notes_current_page], + outputs=[notes_results, notes_status, notes_current_page, notes_total_pages] + ) + + notes_prev_button.click( + fn=update_notes_page, + inputs=[notes_current_page, notes_total_pages, gr.State(-1)], + outputs=[notes_current_page] + ).then( + fn=search_notes, + inputs=[notes_search, notes_current_page], + outputs=[notes_results, notes_status, notes_current_page, notes_total_pages] + ) + + notes_next_button.click( + fn=update_notes_page, + inputs=[notes_current_page, notes_total_pages, gr.State(1)], + outputs=[notes_current_page] + ).then( + fn=search_notes, + inputs=[notes_search, notes_current_page], + outputs=[notes_results, notes_status, notes_current_page, notes_total_pages] + ) + + export_selected_notes_button.click( + fn=export_rag_notes_as_json, + inputs=[notes_results], + outputs=[notes_export_output, notes_status] + ) + + export_all_notes_button.click( + fn=lambda: export_rag_notes_as_json(), + outputs=[notes_export_output, notes_status] + ) + + with gr.TabItem("Export Prompts", visible=True): + gr.Markdown("# Export Prompts Database Content") + + with gr.Row(): + with gr.Column(): + export_type = gr.Radio( + choices=["All Prompts", "Prompts by Keyword"], + label="Export Type", + value="All Prompts" + ) + + # Keyword selection for filtered export + with gr.Column(visible=False) as keyword_col: + keyword_input = gr.Textbox( + label="Enter Keywords (comma-separated)", + placeholder="Enter keywords to filter prompts..." 
+ ) + + # Export format selection + export_format = gr.Radio( + choices=["CSV", "Markdown (ZIP)"], + label="Export Format", + value="CSV" + ) + + # Export options + include_options = gr.CheckboxGroup( + choices=[ + "Include System Prompts", + "Include User Prompts", + "Include Details", + "Include Author", + "Include Keywords" + ], + label="Export Options", + value=["Include Keywords", "Include Author"] + ) + + # Markdown-specific options (only visible when Markdown is selected) + with gr.Column(visible=False) as markdown_options_col: + markdown_template = gr.Radio( + choices=[ + "Basic Template", + "Detailed Template", + "Custom Template" + ], + label="Markdown Template", + value="Basic Template" + ) + custom_template = gr.Textbox( + label="Custom Template", + placeholder="Use {title}, {author}, {details}, {system}, {user}, {keywords} as placeholders", + visible=False + ) + + export_button = gr.Button("Export Prompts") + + with gr.Column(): + export_status = gr.Textbox(label="Export Status", interactive=False) + export_file = gr.File(label="Download Export") + + def update_ui_visibility(export_type, format_choice, template_choice): + """Update UI elements visibility based on selections""" + show_keywords = export_type == "Prompts by Keyword" + show_markdown_options = format_choice == "Markdown (ZIP)" + show_custom_template = template_choice == "Custom Template" and show_markdown_options + + return [ + gr.update(visible=show_keywords), # keyword_col + gr.update(visible=show_markdown_options), # markdown_options_col + gr.update(visible=show_custom_template) # custom_template + ] + + def handle_export(export_type, keywords, export_format, options, markdown_template, custom_template): + """Handle the export process based on selected options""" + try: + # Parse options + include_system = "Include System Prompts" in options + include_user = "Include User Prompts" in options + include_details = "Include Details" in options + include_author = "Include Author" in options + 
include_keywords = "Include Keywords" in options + + # Handle keyword filtering + keyword_list = None + if export_type == "Prompts by Keyword" and keywords: + keyword_list = [k.strip() for k in keywords.split(",") if k.strip()] + + # Get the appropriate template + template = None + if export_format == "Markdown (ZIP)": + if markdown_template == "Custom Template": + template = custom_template + else: + template = markdown_template + + # Perform export + from App_Function_Libraries.DB.Prompts_DB import export_prompts + status, file_path = export_prompts( + export_format=export_format.split()[0].lower(), # 'csv' or 'markdown' + filter_keywords=keyword_list, + include_system=include_system, + include_user=include_user, + include_details=include_details, + include_author=include_author, + include_keywords=include_keywords, + markdown_template=template + ) + + return status, file_path + + except Exception as e: + error_msg = f"Export failed: {str(e)}" + logging.error(error_msg) + return error_msg, None + + # Event handlers + export_type.change( + fn=lambda t, f, m: update_ui_visibility(t, f, m), + inputs=[export_type, export_format, markdown_template], + outputs=[keyword_col, markdown_options_col, custom_template] + ) + + export_format.change( + fn=lambda t, f, m: update_ui_visibility(t, f, m), + inputs=[export_type, export_format, markdown_template], + outputs=[keyword_col, markdown_options_col, custom_template] + ) + + markdown_template.change( + fn=lambda t, f, m: update_ui_visibility(t, f, m), + inputs=[export_type, export_format, markdown_template], + outputs=[keyword_col, markdown_options_col, custom_template] + ) + + export_button.click( + fn=handle_export, + inputs=[ + export_type, + keyword_input, + export_format, + include_options, + markdown_template, + custom_template + ], + outputs=[export_status, export_file] + ) + +# +# End of Export_Functionality.py 
+###################################################################################################################### diff --git a/App_Function_Libraries/Gradio_UI/Gradio_Shared.py b/App_Function_Libraries/Gradio_UI/Gradio_Shared.py index 83925ec9d41f0d68b90729cbdfd9aa7b83b7fb10..16c1048c12319deae2c65b6f810666d91d2539e2 100644 --- a/App_Function_Libraries/Gradio_UI/Gradio_Shared.py +++ b/App_Function_Libraries/Gradio_UI/Gradio_Shared.py @@ -216,11 +216,6 @@ def format_content(content): return formatted_content -def update_prompt_dropdown(): - prompt_names = list_prompts() - return gr.update(choices=prompt_names) - - def display_prompt_details(selected_prompt): if selected_prompt: prompts = update_user_prompt(selected_prompt) diff --git a/App_Function_Libraries/Gradio_UI/Import_Functionality.py b/App_Function_Libraries/Gradio_UI/Import_Functionality.py index c748d2c866fc44f781a2a2e1c3045d7f4deff064..b701d11e42e2caa0806aa4ef8ddde7796b177923 100644 --- a/App_Function_Libraries/Gradio_UI/Import_Functionality.py +++ b/App_Function_Libraries/Gradio_UI/Import_Functionality.py @@ -2,24 +2,31 @@ # Functionality to import content into the DB # # Imports +from datetime import datetime from time import sleep import logging import re import shutil import tempfile import os +from pathlib import Path +import sqlite3 import traceback +from typing import Optional, List, Dict, Tuple +import uuid import zipfile # # External Imports import gradio as gr +from chardet import detect + # # Local Imports -from App_Function_Libraries.DB.DB_Manager import insert_prompt_to_db, load_preset_prompts, import_obsidian_note_to_db, \ - add_media_to_database +from App_Function_Libraries.DB.DB_Manager import insert_prompt_to_db, import_obsidian_note_to_db, \ + add_media_to_database, list_prompts from App_Function_Libraries.Prompt_Handling import import_prompt_from_file, import_prompts_from_zip# from App_Function_Libraries.Summarization.Summarization_General_Lib import perform_summarization - +# 
################################################################################################################### # # Functions: @@ -203,15 +210,6 @@ def create_import_single_prompt_tab(): outputs=save_output ) - def update_prompt_dropdown(): - return gr.update(choices=load_preset_prompts()) - - save_button.click( - fn=update_prompt_dropdown, - inputs=[], - outputs=[gr.Dropdown(label="Select Preset Prompt")] - ) - def create_import_item_tab(): with gr.TabItem("Import Markdown/Text Files", visible=True): gr.Markdown("# Import a markdown file or text file into the database") @@ -250,11 +248,18 @@ def create_import_multiple_prompts_tab(): gr.Markdown("# Import multiple prompts into the database") gr.Markdown("Upload a zip file containing multiple prompt files (txt or md)") + # Initialize state variables for pagination + current_page_state = gr.State(value=1) + total_pages_state = gr.State(value=1) + with gr.Row(): with gr.Column(): zip_file = gr.File(label="Upload zip file for import", file_types=["zip"]) import_button = gr.Button("Import Prompts") prompts_dropdown = gr.Dropdown(label="Select Prompt to Edit", choices=[]) + prev_page_button = gr.Button("Previous Page", visible=False) + page_display = gr.Markdown("Page 1 of X", visible=False) + next_page_button = gr.Button("Next Page", visible=False) title_input = gr.Textbox(label="Title", placeholder="Enter the title of the content") author_input = gr.Textbox(label="Author", placeholder="Enter the author's name") system_input = gr.Textbox(label="System", placeholder="Enter the system message for the prompt", @@ -268,6 +273,10 @@ def create_import_multiple_prompts_tab(): save_output = gr.Textbox(label="Save Status") prompts_display = gr.Textbox(label="Identified Prompts") + # State to store imported prompts + zip_import_state = gr.State([]) + + # Function to handle zip import def handle_zip_import(zip_file): result = import_prompts_from_zip(zip_file) if isinstance(result, list): @@ -278,6 +287,13 @@ def 
create_import_multiple_prompts_tab(): else: return gr.update(value=result), [], gr.update(value=""), [] + import_button.click( + fn=handle_zip_import, + inputs=[zip_file], + outputs=[import_output, prompts_dropdown, prompts_display, zip_import_state] + ) + + # Function to handle prompt selection from imported prompts def handle_prompt_selection(selected_title, prompts): selected_prompt = next((prompt for prompt in prompts if prompt['title'] == selected_title), None) if selected_prompt: @@ -305,23 +321,68 @@ def create_import_multiple_prompts_tab(): outputs=[title_input, author_input, system_input, user_input, keywords_input] ) + # Function to save prompt to the database def save_prompt_to_db(title, author, system, user, keywords): keyword_list = [k.strip() for k in keywords.split(',') if k.strip()] - return insert_prompt_to_db(title, author, system, user, keyword_list) + result = insert_prompt_to_db(title, author, system, user, keyword_list) + return result save_button.click( fn=save_prompt_to_db, inputs=[title_input, author_input, system_input, user_input, keywords_input], - outputs=save_output + outputs=[save_output] + ) + + # Adding pagination controls to navigate prompts in the database + def on_prev_page_click(current_page, total_pages): + new_page = max(current_page - 1, 1) + prompts, total_pages, current_page = list_prompts(page=new_page, per_page=10) + page_display_text = f"Page {current_page} of {total_pages}" + return ( + gr.update(choices=prompts), + gr.update(value=page_display_text), + current_page + ) + + def on_next_page_click(current_page, total_pages): + new_page = min(current_page + 1, total_pages) + prompts, total_pages, current_page = list_prompts(page=new_page, per_page=10) + page_display_text = f"Page {current_page} of {total_pages}" + return ( + gr.update(choices=prompts), + gr.update(value=page_display_text), + current_page + ) + + prev_page_button.click( + fn=on_prev_page_click, + inputs=[current_page_state, total_pages_state], + 
outputs=[prompts_dropdown, page_display, current_page_state] + ) + + next_page_button.click( + fn=on_next_page_click, + inputs=[current_page_state, total_pages_state], + outputs=[prompts_dropdown, page_display, current_page_state] ) + # Function to update prompts dropdown after saving to the database def update_prompt_dropdown(): - return gr.update(choices=load_preset_prompts()) + prompts, total_pages, current_page = list_prompts(page=1, per_page=10) + page_display_text = f"Page {current_page} of {total_pages}" + return ( + gr.update(choices=prompts), + gr.update(visible=True), + gr.update(value=page_display_text, visible=True), + current_page, + total_pages + ) + # Update the dropdown after saving save_button.click( fn=update_prompt_dropdown, inputs=[], - outputs=[gr.Dropdown(label="Select Preset Prompt")] + outputs=[prompts_dropdown, prev_page_button, page_display, current_page_state, total_pages_state] ) @@ -385,4 +446,392 @@ def import_obsidian_vault(vault_path, progress=gr.Progress()): except Exception as e: error_msg = f"Error scanning vault: {str(e)}\n{traceback.format_exc()}" logger.error(error_msg) - return 0, 0, [error_msg] \ No newline at end of file + return 0, 0, [error_msg] + + +class RAGQABatchImporter: + def __init__(self, db_path: str): + self.db_path = Path(db_path) + self.setup_logging() + self.file_processor = FileProcessor() + self.zip_validator = ZipValidator() + + def setup_logging(self): + logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - %(levelname)s - %(message)s', + handlers=[ + logging.FileHandler('rag_qa_import.log'), + logging.StreamHandler() + ] + ) + + def process_markdown_content(self, content: str) -> List[Dict[str, str]]: + """Process markdown content into a conversation format.""" + messages = [] + sections = content.split('\n\n') + + for section in sections: + if section.strip(): + messages.append({ + 'role': 'user', + 'content': section.strip() + }) + + return messages + + def process_keywords(self, db: 
sqlite3.Connection, conversation_id: str, keywords: str): + """Process and link keywords to a conversation.""" + if not keywords: + return + + keyword_list = [k.strip() for k in keywords.split(',')] + for keyword in keyword_list: + # Insert keyword if it doesn't exist + db.execute(""" + INSERT OR IGNORE INTO rag_qa_keywords (keyword) + VALUES (?) + """, (keyword,)) + + # Get keyword ID + keyword_id = db.execute(""" + SELECT id FROM rag_qa_keywords WHERE keyword = ? + """, (keyword,)).fetchone()[0] + + # Link keyword to conversation + db.execute(""" + INSERT INTO rag_qa_conversation_keywords + (conversation_id, keyword_id) + VALUES (?, ?) + """, (conversation_id, keyword_id)) + + def import_single_file( + self, + db: sqlite3.Connection, + content: str, + filename: str, + keywords: str, + custom_prompt: Optional[str] = None, + rating: Optional[int] = None + ) -> str: + """Import a single file's content into the database""" + conversation_id = str(uuid.uuid4()) + current_time = datetime.now().isoformat() + + # Process filename into title + title = FileProcessor.process_filename_to_title(filename) + if title.lower().endswith(('.md', '.txt')): + title = title[:-3] if title.lower().endswith('.md') else title[:-4] + + # Insert conversation metadata + db.execute(""" + INSERT INTO conversation_metadata + (conversation_id, created_at, last_updated, title, rating) + VALUES (?, ?, ?, ?, ?) + """, (conversation_id, current_time, current_time, title, rating)) + + # Process content and insert messages + messages = self.process_markdown_content(content) + for msg in messages: + db.execute(""" + INSERT INTO rag_qa_chats + (conversation_id, timestamp, role, content) + VALUES (?, ?, ?, ?) 
+ """, (conversation_id, current_time, msg['role'], msg['content'])) + + # Process keywords + self.process_keywords(db, conversation_id, keywords) + + return conversation_id + + def extract_zip(self, zip_path: str) -> List[Tuple[str, str]]: + """Extract and validate files from zip""" + is_valid, error_msg, valid_files = self.zip_validator.validate_zip_file(zip_path) + if not is_valid: + raise ValueError(error_msg) + + files = [] + with zipfile.ZipFile(zip_path, 'r') as zip_ref: + for filename in valid_files: + with zip_ref.open(filename) as f: + content = f.read() + # Try to decode with detected encoding + try: + detected_encoding = detect(content)['encoding'] or 'utf-8' + content = content.decode(detected_encoding) + except UnicodeDecodeError: + content = content.decode('utf-8', errors='replace') + + filename = os.path.basename(filename) + files.append((filename, content)) + return files + + def import_files( + self, + files: List[str], + keywords: str = "", + custom_prompt: Optional[str] = None, + rating: Optional[int] = None, + progress=gr.Progress() + ) -> Tuple[bool, str]: + """Import multiple files or zip files into the RAG QA database.""" + try: + imported_files = [] + + with sqlite3.connect(self.db_path) as db: + # Process each file + for file_path in progress.tqdm(files, desc="Processing files"): + filename = os.path.basename(file_path) + + # Handle zip files + if filename.lower().endswith('.zip'): + zip_files = self.extract_zip(file_path) + for zip_filename, content in progress.tqdm(zip_files, desc=f"Processing files from {filename}"): + conv_id = self.import_single_file( + db=db, + content=content, + filename=zip_filename, + keywords=keywords, + custom_prompt=custom_prompt, + rating=rating + ) + imported_files.append(zip_filename) + + # Handle individual markdown/text files + elif filename.lower().endswith(('.md', '.txt')): + with open(file_path, 'r', encoding='utf-8') as f: + content = f.read() + + conv_id = self.import_single_file( + db=db, + 
class FileProcessor:
    """Handles file reading and name processing"""

    VALID_EXTENSIONS = {'.md', '.txt', '.zip'}
    # Fallback encodings, tried in order when the detected encoding fails.
    ENCODINGS_TO_TRY = [
        'utf-8',
        'utf-16',
        'windows-1252',
        'iso-8859-1',
        'ascii'
    ]

    @staticmethod
    def detect_encoding(file_path: str) -> str:
        """Detect the file encoding using chardet.

        Returns the detected encoding name, or ``'utf-8'`` when the detector
        reports no encoding.
        """
        with open(file_path, 'rb') as file:
            raw_data = file.read()
        result = detect(raw_data)
        return result['encoding'] or 'utf-8'

    @staticmethod
    def read_file_content(file_path: str) -> str:
        """Read file content with automatic encoding detection.

        The detected encoding is tried first, then every entry of
        ``ENCODINGS_TO_TRY``. An encoding is skipped when it cannot decode the
        file (UnicodeDecodeError) or when Python does not know the codec name
        the detector returned (LookupError). If nothing works the file is read
        as UTF-8 with undecodable bytes replaced.
        """
        detected_encoding = FileProcessor.detect_encoding(file_path)

        for encoding in [detected_encoding, *FileProcessor.ENCODINGS_TO_TRY]:
            try:
                with open(file_path, 'r', encoding=encoding) as f:
                    return f.read()
            except (UnicodeDecodeError, LookupError):
                continue

        # If all encodings fail, use utf-8 with error handling
        with open(file_path, 'r', encoding='utf-8', errors='replace') as f:
            return f.read()

    @staticmethod
    def process_filename_to_title(filename: str) -> str:
        """Convert filename to a readable title.

        The extension is removed, a ``YYYY[-_]MM[-_]DD`` date (separators
        optional) is pulled out and appended as ``"<title> - Mon DD, YYYY"``,
        ``-``/``_`` runs become spaces and words are capitalized, except for a
        small set of stop words that are not in first position.

        Args:
            filename: File name, with or without extension.

        Returns:
            The human-readable title. When the name consists only of a date,
            the formatted date alone is returned.
        """
        # Remove extension
        name = os.path.splitext(filename)[0]

        # Look for date patterns
        date_pattern = r'(\d{4}[-_]?\d{2}[-_]?\d{2})'
        date_match = re.search(date_pattern, name)
        date_str = ""
        if date_match:
            raw_date = date_match.group(1)
            try:
                # Drop the optional separators before parsing so every form
                # the regex accepts (2024-01-15, 2024_01_15, 20240115) is
                # handled by the same format string.
                date = datetime.strptime(re.sub(r'[-_]', '', raw_date), '%Y%m%d')
                date_str = date.strftime("%b %d, %Y")
                name = name.replace(raw_date, '').strip('-_')
            except ValueError:
                # Eight digits that are not a real calendar date: keep them
                # as part of the title.
                pass

        # Replace separators with spaces
        name = re.sub(r'[-_]+', ' ', name)

        # Remove redundant spaces
        name = re.sub(r'\s+', ' ', name).strip()

        # Capitalize words, excluding certain words
        exclude_words = {'a', 'an', 'the', 'in', 'on', 'at', 'to', 'for', 'of', 'with'}
        capitalized = [
            word.capitalize() if i == 0 or word.lower() not in exclude_words else word.lower()
            for i, word in enumerate(name.split())
        ]
        name = ' '.join(capitalized)

        # Add date if found; avoid a dangling " - " when nothing else is left
        if date_str:
            name = f"{name} - {date_str}" if name else date_str

        return name
def create_conversation_import_tab() -> gr.Tab:
    """Create the import tab for the Gradio interface.

    The tab lets the user upload markdown/text files (or zip archives of
    them), attach keywords, an optional custom prompt and a rating, and then
    runs the upload through ``RAGQABatchImporter.import_files``.

    Returns:
        The ``gr.Tab`` that was built.
    """
    with gr.Tab("Import RAG Chats") as tab:
        gr.Markdown("# Import RAG Chats into the Database")
        gr.Markdown("""
        Import your RAG Chat markdown/text files individually or as a zip archive

        Supported file types:
        - Markdown (.md)
        - Text (.txt)
        - Zip archives containing .md or .txt files

        Maximum zip file size: 100MB
        Maximum files per zip: 100
        """)
        with gr.Row():
            with gr.Column():
                import_files = gr.File(
                    label="Upload Files",
                    file_types=["txt", "md", "zip"],
                    file_count="multiple"
                )

                keywords_input = gr.Textbox(
                    label="Keywords",
                    placeholder="Enter keywords to apply to all imported files (comma-separated)"
                )

                custom_prompt_input = gr.Textbox(
                    label="Custom Prompt",
                    placeholder="Enter a custom prompt for processing (optional)"
                )

                rating_input = gr.Slider(
                    minimum=1,
                    maximum=3,
                    step=1,
                    label="Rating (1-3)",
                    value=None
                )

            with gr.Column():
                import_button = gr.Button("Import Files")
                import_output = gr.Textbox(
                    label="Import Status",
                    lines=10
                )

        def handle_import(files, keywords, custom_prompt, rating):
            """Run the batch import for the uploaded files and return its status text."""
            # Clicking the button with nothing uploaded delivers None (or an
            # empty list); report that instead of failing while iterating.
            if not files:
                return "No files selected. Please upload at least one file."

            importer = RAGQABatchImporter("rag_qa.db")  # Update with your DB path
            # Uploaded entries may be temp-file objects exposing `.name` or
            # plain path strings; accept both.
            _success, message = importer.import_files(
                files=[getattr(f, "name", f) for f in files],
                keywords=keywords,
                custom_prompt=custom_prompt,
                rating=rating
            )
            return message

        import_button.click(
            fn=handle_import,
            inputs=[
                import_files,
                keywords_input,
                custom_prompt_input,
                rating_input
            ],
            outputs=import_output
        )

    return tab
# # Imports: - # # External Imports import gradio as gr + +from App_Function_Libraries.DB.Character_Chat_DB import view_char_keywords, add_char_keywords, delete_char_keyword, \ + export_char_keywords_to_csv # # Internal Imports from App_Function_Libraries.DB.DB_Manager import add_keyword, delete_keyword, keywords_browser_interface, export_keywords_to_csv -# +from App_Function_Libraries.DB.Prompts_DB import view_prompt_keywords, delete_prompt_keyword, \ + export_prompt_keywords_to_csv +from App_Function_Libraries.DB.RAG_QA_Chat_DB import view_rag_keywords, get_all_collections, \ + get_keywords_for_collection, create_keyword_collection, add_keyword_to_collection, delete_rag_keyword, \ + export_rag_keywords_to_csv + + # ###################################################################################################################### # # Functions: - def create_export_keywords_tab(): - with gr.TabItem("Export Keywords", visible=True): + with gr.TabItem("Export MediaDB Keywords", visible=True): with gr.Row(): with gr.Column(): export_keywords_button = gr.Button("Export Keywords") @@ -33,8 +40,8 @@ def create_export_keywords_tab(): ) def create_view_keywords_tab(): - with gr.TabItem("View Keywords", visible=True): - gr.Markdown("# Browse Keywords") + with gr.TabItem("View MediaDB Keywords", visible=True): + gr.Markdown("# Browse MediaDB Keywords") with gr.Column(): browse_output = gr.Markdown() browse_button = gr.Button("View Existing Keywords") @@ -42,7 +49,7 @@ def create_view_keywords_tab(): def create_add_keyword_tab(): - with gr.TabItem("Add Keywords", visible=True): + with gr.TabItem("Add MediaDB Keywords", visible=True): with gr.Row(): with gr.Column(): gr.Markdown("# Add Keywords to the Database") @@ -54,7 +61,7 @@ def create_add_keyword_tab(): def create_delete_keyword_tab(): - with gr.Tab("Delete Keywords", visible=True): + with gr.Tab("Delete MediaDB Keywords", visible=True): with gr.Row(): with gr.Column(): gr.Markdown("# Delete Keywords from the 
def create_character_keywords_tab():
    """Creates the Character Keywords management tab.

    Builds four sub-tabs (view, add, delete, export) and wires each button to
    the matching character-keyword helper. The view output is filled only
    when the user presses the refresh button.
    """
    with gr.Tab("Character Keywords"):
        gr.Markdown("# Character Keywords Management")

        with gr.Tabs():
            # View Character Keywords Tab
            with gr.TabItem("View Keywords"):
                with gr.Column():
                    refresh_char_keywords = gr.Button("Refresh Character Keywords")
                    char_keywords_output = gr.Markdown()
                    # NOTE: the bare `view_char_keywords()` call that used to
                    # sit here ran the lookup while the UI was being built and
                    # threw the result away; it has been removed.
                    refresh_char_keywords.click(
                        fn=view_char_keywords,
                        outputs=char_keywords_output
                    )

            # Add Character Keywords Tab
            with gr.TabItem("Add Keywords"):
                with gr.Column():
                    char_name = gr.Textbox(label="Character Name")
                    new_keywords = gr.Textbox(label="New Keywords (comma-separated)")
                    add_char_keyword_btn = gr.Button("Add Keywords")
                    add_char_result = gr.Markdown()

                    add_char_keyword_btn.click(
                        fn=add_char_keywords,
                        inputs=[char_name, new_keywords],
                        outputs=add_char_result
                    )

            # Delete Character Keywords Tab (New)
            with gr.TabItem("Delete Keywords"):
                with gr.Column():
                    delete_char_name = gr.Textbox(label="Character Name")
                    delete_char_keyword_input = gr.Textbox(label="Keyword to Delete")
                    delete_char_keyword_btn = gr.Button("Delete Keyword")
                    delete_char_result = gr.Markdown()

                    delete_char_keyword_btn.click(
                        fn=delete_char_keyword,
                        inputs=[delete_char_name, delete_char_keyword_input],
                        outputs=delete_char_result
                    )

            # Export Character Keywords Tab (New)
            with gr.TabItem("Export Keywords"):
                with gr.Column():
                    export_char_keywords_btn = gr.Button("Export Character Keywords")
                    export_char_file = gr.File(label="Download Exported Keywords")
                    export_char_status = gr.Textbox(label="Export Status")

                    # The outputs list is [status, file], in that order.
                    export_char_keywords_btn.click(
                        fn=export_char_keywords_to_csv,
                        outputs=[export_char_status, export_char_file]
                    )
def create_prompt_keywords_tab():
    """Build the "Prompt Keywords" tab with its view / add / delete / export sub-tabs."""
    with gr.Tab("Prompt Keywords"):
        gr.Markdown("# Prompt Keywords Management")

        with gr.Tabs():
            # --- View -------------------------------------------------------
            with gr.TabItem("View Keywords"):
                with gr.Column():
                    refresh_btn = gr.Button("Refresh Prompt Keywords")
                    keywords_md = gr.Markdown()

                    refresh_btn.click(fn=view_prompt_keywords, outputs=keywords_md)

            # --- Add (informational only; keywords are added elsewhere) ------
            with gr.TabItem("Add Keywords"):
                gr.Markdown("""
                    To add keywords to prompts, please use the Prompt Management interface.
                    Keywords can be added when creating or editing a prompt.
                    """)

            # --- Delete -----------------------------------------------------
            with gr.TabItem("Delete Keywords"):
                with gr.Column():
                    keyword_box = gr.Textbox(label="Keyword to Delete")
                    delete_btn = gr.Button("Delete Keyword")
                    delete_md = gr.Markdown()

                    delete_btn.click(
                        fn=delete_prompt_keyword,
                        inputs=keyword_box,
                        outputs=delete_md
                    )

            # --- Export -----------------------------------------------------
            with gr.TabItem("Export Keywords"):
                with gr.Column():
                    export_btn = gr.Button("Export Prompt Keywords")
                    status_box = gr.Textbox(label="Export Status", interactive=False)
                    file_box = gr.File(label="Download Exported Keywords", interactive=False)

                    def handle_export():
                        # Any falsy path coming back from the exporter is
                        # handed to the File component as None.
                        status, file_path = export_prompt_keywords_to_csv()
                        return status, (file_path if file_path else None)

                    export_btn.click(
                        fn=handle_export,
                        outputs=[status_box, file_box]
                    )
def create_meta_keywords_tab():
    """Creates the Meta-Keywords management tab.

    Provides three sub-tabs: list all keyword collections with their
    keywords, create a new collection, and add comma-separated keywords to an
    existing collection. Each handler returns a status string and turns any
    exception into an error message instead of raising.
    """
    with gr.Tab("Meta-Keywords"):
        gr.Markdown("# Meta-Keywords Management")

        with gr.Tabs():
            # View Meta-Keywords Tab
            with gr.TabItem("View Collections"):
                with gr.Column():
                    refresh_collections = gr.Button("Refresh Collections")
                    collections_output = gr.Markdown()

                    def view_collections():
                        """Render every collection and its keywords as markdown."""
                        try:
                            # Only the first element of the returned triple is used here.
                            collections, _, _ = get_all_collections()
                            if not collections:
                                return "No collections found."
                            result = "### Keyword Collections:\n"
                            for collection in collections:
                                keywords = get_keywords_for_collection(collection)
                                result += f"\n**{collection}**:\n"
                                result += "\n".join([f"- {k}" for k in keywords])
                                result += "\n"
                            return result
                        except Exception as e:
                            return f"Error retrieving collections: {str(e)}"

                    refresh_collections.click(
                        fn=view_collections,
                        outputs=collections_output
                    )

            # Create Collection Tab
            with gr.TabItem("Create Collection"):
                with gr.Column():
                    collection_name = gr.Textbox(label="Collection Name")
                    create_collection_btn = gr.Button("Create Collection")
                    create_result = gr.Markdown()

                    def create_collection(name: str):
                        """Create a collection; blank names are rejected before reaching the DB."""
                        name = (name or "").strip()
                        if not name:
                            return "Error creating collection: collection name is empty"
                        try:
                            create_keyword_collection(name)
                            return f"Successfully created collection: {name}"
                        except Exception as e:
                            return f"Error creating collection: {str(e)}"

                    create_collection_btn.click(
                        fn=create_collection,
                        inputs=collection_name,
                        outputs=create_result
                    )

            # Add Keywords to Collection Tab
            with gr.TabItem("Add to Collection"):
                with gr.Column():
                    collection_select = gr.Textbox(label="Collection Name")
                    keywords_to_add = gr.Textbox(label="Keywords to Add (comma-separated)")
                    add_to_collection_btn = gr.Button("Add Keywords to Collection")
                    add_to_collection_result = gr.Markdown()

                    def add_keywords_to_collection(collection: str, keywords: str):
                        """Add each non-empty comma-separated keyword to the named collection."""
                        collection = (collection or "").strip()
                        keywords_list = [k.strip() for k in (keywords or "").split(",") if k.strip()]
                        # Report missing input instead of claiming that
                        # "0 keywords" were added successfully.
                        if not collection:
                            return "Error adding keywords to collection: collection name is empty"
                        if not keywords_list:
                            return "Error adding keywords to collection: no keywords provided"
                        try:
                            for keyword in keywords_list:
                                add_keyword_to_collection(collection, keyword)
                            return f"Successfully added {len(keywords_list)} keywords to collection {collection}"
                        except Exception as e:
                            return f"Error adding keywords to collection: {str(e)}"

                    add_to_collection_btn.click(
                        fn=add_keywords_to_collection,
                        inputs=[collection_select, keywords_to_add],
                        outputs=add_to_collection_result
                    )
def create_chat_with_llamafile_tab():
    """Build the "Local LLM with Llamafile" tab.

    The tab offers a settings column (basic plus advanced options), a model
    picker that scans a directory, a preset-model downloader and a start
    button that forwards every setting to ``start_llamafile``.
    """

    # Function to update model path based on selection
    def on_local_model_change(selected_model: str, search_directory: str) -> str:
        """Return the absolute path of the chosen model, or "" when nothing valid is selected."""
        if selected_model and isinstance(search_directory, str):
            model_path = os.path.abspath(os.path.join(search_directory, selected_model))
            logging.debug(f"Selected model path: {model_path}")  # Debug print for selected model path
            return model_path
        # This value lands in the model-path textbox, which is later passed to
        # start_llamafile as the model path, so clear it rather than writing
        # an error sentence into a path field.
        return ""

    # Function to update the dropdown with available models
    def update_dropdowns(search_directory: str) -> Tuple[dict, str]:
        """Scan ``search_directory`` and return (dropdown update, status message)."""
        logging.debug(f"User-entered directory: {search_directory}")  # Debug print for directory
        if not search_directory or not os.path.isdir(search_directory):
            logging.debug(f"Directory does not exist: {search_directory}")  # Debug print for non-existing directory
            return gr.update(choices=[], value=None), "Directory does not exist."

        try:
            logging.debug(f"Directory exists: {search_directory}, scanning for files...")  # Confirm directory exists
            model_files = get_gguf_llamafile_files(search_directory)
            logging.debug("Completed scanning for model files.")
        except Exception as e:
            logging.error(f"Error scanning directory: {e}")
            return gr.update(choices=[], value=None), f"Error scanning directory: {e}"

        if not model_files:
            logging.debug(f"No model files found in {search_directory}")  # Debug print for no files found
            return gr.update(choices=[], value=None), "No model files found in the specified directory."

        # Update the dropdown choices with the model files found
        logging.debug(f"Models loaded from {search_directory}: {model_files}")  # Debug: Print model files loaded
        return gr.update(choices=model_files, value=None), f"Models loaded from {search_directory}."

    def download_preset_model(selected_model: str) -> Tuple[str, str]:
        """
        Downloads the selected preset model.

        Args:
            selected_model (str): The key of the selected preset model.

        Returns:
            Tuple[str, str]: Status message and the path to the downloaded model.
        """
        model_info = llm_models.get(selected_model)
        if not model_info:
            return "Invalid model selection.", ""

        try:
            model_path = download_llm_model(
                model_name=model_info["name"],
                model_url=model_info["url"],
                model_filename=model_info["filename"],
                model_hash=model_info["hash"]
            )
            return f"Model '{model_info['name']}' downloaded successfully.", model_path
        except Exception as e:
            logging.error(f"Error downloading model: {e}")
            return f"Failed to download model: {e}", ""

    with gr.TabItem("Local LLM with Llamafile", visible=True):
        gr.Markdown("# Settings for Llamafile")

        with gr.Row():
            with gr.Column():
                am_noob = gr.Checkbox(label="Enable Sane Defaults", value=False, visible=True)
                advanced_mode_toggle = gr.Checkbox(label="Advanced Mode - Show All Settings", value=False)
                # Advanced Inputs
                verbose_checked = gr.Checkbox(label="Enable Verbose Output", value=False, visible=False)
                threads_checked = gr.Checkbox(label="Set CPU Threads", value=False, visible=False)
                threads_value = gr.Number(label="Number of CPU Threads", value=None, precision=0, visible=False)
                threads_batched_checked = gr.Checkbox(label="Enable Batched Inference", value=False, visible=False)
                threads_batched_value = gr.Number(label="Batch Size for Inference", value=None, precision=0, visible=False)
                model_alias_checked = gr.Checkbox(label="Set Model Alias", value=False, visible=False)
                model_alias_value = gr.Textbox(label="Model Alias", value="", visible=False)
                ctx_size_checked = gr.Checkbox(label="Set Prompt Context Size", value=False, visible=False)
                ctx_size_value = gr.Number(label="Prompt Context Size", value=8124, precision=0, visible=False)
                # The GPU-layer controls are created visible and are not part
                # of the advanced toggle.
                ngl_checked = gr.Checkbox(label="Enable GPU Layers", value=False, visible=True)
                ngl_value = gr.Number(label="Number of GPU Layers", value=None, precision=0, visible=True)
                batch_size_checked = gr.Checkbox(label="Set Batch Size", value=False, visible=False)
                batch_size_value = gr.Number(label="Batch Size", value=512, visible=False)
                memory_f32_checked = gr.Checkbox(label="Use 32-bit Floating Point", value=False, visible=False)
                numa_checked = gr.Checkbox(label="Enable NUMA", value=False, visible=False)
                server_timeout_value = gr.Number(label="Server Timeout", value=600, precision=0, visible=False)
                host_checked = gr.Checkbox(label="Set IP to Listen On", value=False, visible=False)
                host_value = gr.Textbox(label="Host IP Address", value="", visible=False)
                port_checked = gr.Checkbox(label="Set Server Port", value=False, visible=False)
                port_value = gr.Number(label="Port Number", value=8080, precision=0, visible=False)
                api_key_checked = gr.Checkbox(label="Set API Key", value=False, visible=False)
                api_key_value = gr.Textbox(label="API Key", value="", visible=False)
                http_threads_checked = gr.Checkbox(label="Set HTTP Server Threads", value=False, visible=False)
                http_threads_value = gr.Number(label="Number of HTTP Server Threads", value=None, precision=0, visible=False)
                hf_repo_checked = gr.Checkbox(label="Use Huggingface Repo Model", value=False, visible=False)
                hf_repo_value = gr.Textbox(label="Huggingface Repo Name", value="", visible=False)
                hf_file_checked = gr.Checkbox(label="Set Huggingface Model File", value=False, visible=False)
                hf_file_value = gr.Textbox(label="Huggingface Model File", value="", visible=False)

            with gr.Column():
                # Model Selection Section
                gr.Markdown("## Model Selection")

                # Option 1: Select from Local Filesystem
                with gr.Row():
                    search_directory = gr.Textbox(
                        label="Model Directory",
                        placeholder="Enter directory path (currently './Models')",
                        value=MODELS_DIR,
                        interactive=True
                    )

                # Initial population of local models
                logging.debug(f"Scanning directory: {MODELS_DIR}")
                initial_dropdown_update, _ = update_dropdowns(MODELS_DIR)
                refresh_button = gr.Button("Refresh Models")
                local_model_dropdown = gr.Dropdown(
                    label="Select Model from Directory",
                    choices=initial_dropdown_update["choices"],
                    value=None
                )
                # Display selected model path
                model_value = gr.Textbox(label="Selected Model File Path", value="", interactive=False)

                # Option 2: Download Preset Models
                gr.Markdown("## Download Preset Models")

                preset_model_dropdown = gr.Dropdown(
                    label="Select a Preset Model",
                    choices=list(llm_models.keys()),
                    value=None,
                    interactive=True,
                    info="Choose a preset model to download."
                )
                download_preset_button = gr.Button("Download Selected Preset")

        with gr.Row():
            with gr.Column():
                start_button = gr.Button("Start Llamafile")
                stop_button = gr.Button("Stop Llamafile (doesn't work)")
                output_display = gr.Markdown()

        # Every component that is created hidden. One list feeds both the
        # visibility handler and its outputs, so the two cannot drift apart
        # and "Show All Settings" really reveals all of them.
        advanced_components = [
            verbose_checked, threads_checked, threads_value,
            threads_batched_checked, threads_batched_value,
            model_alias_checked, model_alias_value,
            ctx_size_checked, ctx_size_value,
            batch_size_checked, batch_size_value,
            memory_f32_checked, numa_checked, server_timeout_value,
            host_checked, host_value,
            port_checked, port_value,
            api_key_checked, api_key_value,
            http_threads_checked, http_threads_value,
            hf_repo_checked, hf_repo_value,
            hf_file_checked, hf_file_value
        ]

        # Show/hide advanced inputs based on toggle
        def update_visibility(show_advanced: bool):
            """Return one visibility update per advanced component."""
            return [gr.update(visible=show_advanced) for _ in advanced_components]

        def on_start_button_click(
                am_noob: bool,
                verbose_checked: bool,
                threads_checked: bool,
                threads_value: Optional[int],
                threads_batched_checked: bool,
                threads_batched_value: Optional[int],
                model_alias_checked: bool,
                model_alias_value: str,
                http_threads_checked: bool,
                http_threads_value: Optional[int],
                model_value: str,
                hf_repo_checked: bool,
                hf_repo_value: str,
                hf_file_checked: bool,
                hf_file_value: str,
                ctx_size_checked: bool,
                ctx_size_value: Optional[int],
                ngl_checked: bool,
                ngl_value: Optional[int],
                batch_size_checked: bool,
                batch_size_value: Optional[int],
                memory_f32_checked: bool,
                numa_checked: bool,
                server_timeout_value: Optional[int],
                host_checked: bool,
                host_value: str,
                port_checked: bool,
                port_value: Optional[int],
                api_key_checked: bool,
                api_key_value: str
        ) -> str:
            """
            Event handler for the Start Llamafile button.

            Forwards all settings positionally to ``start_llamafile`` and
            returns its result; any exception is logged and reported as text.
            """
            try:
                return start_llamafile(
                    am_noob,
                    verbose_checked,
                    threads_checked,
                    threads_value,
                    threads_batched_checked,
                    threads_batched_value,
                    model_alias_checked,
                    model_alias_value,
                    http_threads_checked,
                    http_threads_value,
                    model_value,
                    hf_repo_checked,
                    hf_repo_value,
                    hf_file_checked,
                    hf_file_value,
                    ctx_size_checked,
                    ctx_size_value,
                    ngl_checked,
                    ngl_value,
                    batch_size_checked,
                    batch_size_value,
                    memory_f32_checked,
                    numa_checked,
                    server_timeout_value,
                    host_checked,
                    host_value,
                    port_checked,
                    port_value,
                    api_key_checked,
                    api_key_value
                )
            except Exception as e:
                logging.error(f"Error starting Llamafile: {e}")
                return f"Failed to start Llamafile: {e}"

        advanced_mode_toggle.change(
            fn=update_visibility,
            inputs=[advanced_mode_toggle],
            outputs=advanced_components
        )

        # The order of this list must match the parameter order of
        # on_start_button_click.
        start_button.click(
            fn=on_start_button_click,
            inputs=[
                am_noob,
                verbose_checked,
                threads_checked,
                threads_value,
                threads_batched_checked,
                threads_batched_value,
                model_alias_checked,
                model_alias_value,
                http_threads_checked,
                http_threads_value,
                model_value,
                hf_repo_checked,
                hf_repo_value,
                hf_file_checked,
                hf_file_value,
                ctx_size_checked,
                ctx_size_value,
                ngl_checked,
                ngl_value,
                batch_size_checked,
                batch_size_value,
                memory_f32_checked,
                numa_checked,
                server_timeout_value,
                host_checked,
                host_value,
                port_checked,
                port_value,
                api_key_checked,
                api_key_value
            ],
            outputs=output_display
        )

        download_preset_button.click(
            fn=download_preset_model,
            inputs=[preset_model_dropdown],
            outputs=[output_display, model_value]
        )

        # Click event for refreshing models
        refresh_button.click(
            fn=update_dropdowns,
            inputs=[search_directory],  # Ensure that the directory path (string) is passed
            outputs=[local_model_dropdown, output_display]  # Update dropdown and status
        )

        # Event to update model_value when a model is selected from the dropdown
        local_model_dropdown.change(
            fn=on_local_model_change,  # Function that calculates the model path
            inputs=[local_model_dropdown, search_directory],  # Inputs: selected model and directory
            outputs=[model_value]  # Output: Update the model_value textbox with the selected model path
        )
visible=True): + with gr.TabItem("Clone and Edit Existing Items in the Media DB", visible=True): gr.Markdown("# Search, Edit, and Clone Existing Items") with gr.Row(): @@ -199,6 +199,11 @@ def create_media_edit_and_clone_tab(): def create_prompt_edit_tab(): + # Initialize state variables for pagination + current_page_state = gr.State(value=1) + total_pages_state = gr.State(value=1) + per_page = 10 # Number of prompts per page + with gr.TabItem("Add & Edit Prompts", visible=True): with gr.Row(): with gr.Column(): @@ -207,38 +212,145 @@ def create_prompt_edit_tab(): choices=[], interactive=True ) + next_page_button = gr.Button("Next Page", visible=False) + page_display = gr.Markdown("Page 1 of X", visible=False) + prev_page_button = gr.Button("Previous Page", visible=False) prompt_list_button = gr.Button("List Prompts") with gr.Column(): title_input = gr.Textbox(label="Title", placeholder="Enter the prompt title") - author_input = gr.Textbox(label="Author", placeholder="Enter the prompt's author", lines=3) + author_input = gr.Textbox(label="Author", placeholder="Enter the prompt's author", lines=1) description_input = gr.Textbox(label="Description", placeholder="Enter the prompt description", lines=3) system_prompt_input = gr.Textbox(label="System Prompt", placeholder="Enter the system prompt", lines=3) user_prompt_input = gr.Textbox(label="User Prompt", placeholder="Enter the user prompt", lines=3) add_prompt_button = gr.Button("Add/Update Prompt") add_prompt_output = gr.HTML() - # Event handlers + # Function to update the prompt dropdown with pagination + def update_prompt_dropdown(page=1): + prompts, total_pages, current_page = list_prompts(page=page, per_page=per_page) + page_display_text = f"Page {current_page} of {total_pages}" + prev_button_visible = current_page > 1 + next_button_visible = current_page < total_pages + return ( + gr.update(choices=prompts), + gr.update(value=page_display_text, visible=True), + gr.update(visible=prev_button_visible), + 
gr.update(visible=next_button_visible), + current_page, + total_pages + ) + + # Event handler for listing prompts prompt_list_button.click( fn=update_prompt_dropdown, - outputs=prompt_dropdown + inputs=[], + outputs=[ + prompt_dropdown, + page_display, + prev_page_button, + next_page_button, + current_page_state, + total_pages_state + ] + ) + + # Functions to handle pagination + def on_prev_page_click(current_page): + new_page = max(current_page - 1, 1) + return update_prompt_dropdown(page=new_page) + + def on_next_page_click(current_page, total_pages): + new_page = min(current_page + 1, total_pages) + return update_prompt_dropdown(page=new_page) + + # Event handlers for pagination buttons + prev_page_button.click( + fn=on_prev_page_click, + inputs=[current_page_state], + outputs=[ + prompt_dropdown, + page_display, + prev_page_button, + next_page_button, + current_page_state, + total_pages_state + ] + ) + + next_page_button.click( + fn=on_next_page_click, + inputs=[current_page_state, total_pages_state], + outputs=[ + prompt_dropdown, + page_display, + prev_page_button, + next_page_button, + current_page_state, + total_pages_state + ] ) + # Event handler for adding or updating a prompt add_prompt_button.click( fn=add_or_update_prompt, inputs=[title_input, author_input, description_input, system_prompt_input, user_prompt_input], - outputs=add_prompt_output + outputs=[add_prompt_output] + ).then( + fn=update_prompt_dropdown, + inputs=[], + outputs=[ + prompt_dropdown, + page_display, + prev_page_button, + next_page_button, + current_page_state, + total_pages_state + ] ) - # Load prompt details when selected + # Function to load prompt details when a prompt is selected + def load_prompt_details(selected_prompt): + details = fetch_prompt_details(selected_prompt) + if details: + title, author, description, system_prompt, user_prompt, keywords = details + return ( + gr.update(value=title), + gr.update(value=author or ""), + gr.update(value=description or ""), + 
gr.update(value=system_prompt or ""), + gr.update(value=user_prompt or "") + ) + else: + return ( + gr.update(value=""), + gr.update(value=""), + gr.update(value=""), + gr.update(value=""), + gr.update(value="") + ) + + # Event handler for prompt selection change prompt_dropdown.change( fn=load_prompt_details, inputs=[prompt_dropdown], - outputs=[title_input, author_input, system_prompt_input, user_prompt_input] + outputs=[ + title_input, + author_input, + description_input, + system_prompt_input, + user_prompt_input + ] ) + def create_prompt_clone_tab(): + # Initialize state variables for pagination + current_page_state = gr.State(value=1) + total_pages_state = gr.State(value=1) + per_page = 10 # Number of prompts per page + with gr.TabItem("Clone and Edit Prompts", visible=True): with gr.Row(): with gr.Column(): @@ -248,6 +360,9 @@ def create_prompt_clone_tab(): choices=[], interactive=True ) + next_page_button = gr.Button("Next Page", visible=False) + page_display = gr.Markdown("Page 1 of X", visible=False) + prev_page_button = gr.Button("Previous Page", visible=False) prompt_list_button = gr.Button("List Prompts") with gr.Column(): @@ -260,19 +375,99 @@ def create_prompt_clone_tab(): save_cloned_prompt_button = gr.Button("Save Cloned Prompt", visible=False) add_prompt_output = gr.HTML() - # Event handlers + # Function to update the prompt dropdown with pagination + def update_prompt_dropdown(page=1): + prompts, total_pages, current_page = list_prompts(page=page, per_page=per_page) + page_display_text = f"Page {current_page} of {total_pages}" + prev_button_visible = current_page > 1 + next_button_visible = current_page < total_pages + return ( + gr.update(choices=prompts), + gr.update(value=page_display_text, visible=True), + gr.update(visible=prev_button_visible), + gr.update(visible=next_button_visible), + current_page, + total_pages + ) + + # Event handler for listing prompts prompt_list_button.click( fn=update_prompt_dropdown, - outputs=prompt_dropdown + 
inputs=[], + outputs=[ + prompt_dropdown, + page_display, + prev_page_button, + next_page_button, + current_page_state, + total_pages_state + ] + ) + + # Functions to handle pagination + def on_prev_page_click(current_page): + new_page = max(current_page - 1, 1) + return update_prompt_dropdown(page=new_page) + + def on_next_page_click(current_page, total_pages): + new_page = min(current_page + 1, total_pages) + return update_prompt_dropdown(page=new_page) + + # Event handlers for pagination buttons + prev_page_button.click( + fn=on_prev_page_click, + inputs=[current_page_state], + outputs=[ + prompt_dropdown, + page_display, + prev_page_button, + next_page_button, + current_page_state, + total_pages_state + ] + ) + + next_page_button.click( + fn=on_next_page_click, + inputs=[current_page_state, total_pages_state], + outputs=[ + prompt_dropdown, + page_display, + prev_page_button, + next_page_button, + current_page_state, + total_pages_state + ] ) # Load prompt details when selected + def load_prompt_details(selected_prompt): + if selected_prompt: + details = fetch_prompt_details(selected_prompt) + if details: + title, author, description, system_prompt, user_prompt, keywords = details + return ( + gr.update(value=title), + gr.update(value=author or ""), + gr.update(value=description or ""), + gr.update(value=system_prompt or ""), + gr.update(value=user_prompt or "") + ) + return ( + gr.update(value=""), + gr.update(value=""), + gr.update(value=""), + gr.update(value=""), + gr.update(value="") + ) + prompt_dropdown.change( fn=load_prompt_details, inputs=[prompt_dropdown], outputs=[title_input, author_input, description_input, system_prompt_input, user_prompt_input] ) + # Prepare for cloning def prepare_for_cloning(selected_prompt): if selected_prompt: return gr.update(value=f"Copy of {selected_prompt}"), gr.update(visible=True) @@ -284,18 +479,21 @@ def create_prompt_clone_tab(): outputs=[title_input, save_cloned_prompt_button] ) - def save_cloned_prompt(title, 
description, system_prompt, user_prompt): + # Function to save cloned prompt + def save_cloned_prompt(title, author, description, system_prompt, user_prompt, current_page): try: - result = add_prompt(title, description, system_prompt, user_prompt) + result = add_prompt(title, author, description, system_prompt, user_prompt) if result == "Prompt added successfully.": - return result, gr.update(choices=update_prompt_dropdown()) + # After adding, refresh the prompt dropdown + prompt_dropdown_update = update_prompt_dropdown(page=current_page) + return (result, *prompt_dropdown_update) else: - return result, gr.update() + return (result, gr.update(), gr.update(), gr.update(), gr.update(), current_page, total_pages_state.value) except Exception as e: - return f"Error saving cloned prompt: {str(e)}", gr.update() + return (f"Error saving cloned prompt: {str(e)}", gr.update(), gr.update(), gr.update(), gr.update(), current_page, total_pages_state.value) save_cloned_prompt_button.click( fn=save_cloned_prompt, - inputs=[title_input, description_input, system_prompt_input, user_prompt_input], - outputs=[add_prompt_output, prompt_dropdown] - ) \ No newline at end of file + inputs=[title_input, author_input, description_input, system_prompt_input, user_prompt_input, current_page_state], + outputs=[add_prompt_output, prompt_dropdown, page_display, prev_page_button, next_page_button, current_page_state, total_pages_state] + ) diff --git a/App_Function_Libraries/Gradio_UI/Media_wiki_tab.py b/App_Function_Libraries/Gradio_UI/Media_wiki_tab.py index 40cac2389d66dcd4aecd98beae181e2a4d397040..7f0d27f870c30bf2dc90415af3f049724e809555 100644 --- a/App_Function_Libraries/Gradio_UI/Media_wiki_tab.py +++ b/App_Function_Libraries/Gradio_UI/Media_wiki_tab.py @@ -32,6 +32,13 @@ def create_mediawiki_import_tab(): value="sentences", label="Chunking Method" ) + # FIXME - add API selection dropdown + Analysis/Summarization options + # Refactored API selection dropdown + # api_name_input = 
gr.Dropdown( + # choices=["None"] + [format_api_name(api) for api in global_api_endpoints], + # value=default_value, + # label="API for Summarization (Optional)" + # ) chunk_size = gr.Slider(minimum=100, maximum=2000, value=1000, step=100, label="Chunk Size") chunk_overlap = gr.Slider(minimum=0, maximum=500, value=100, step=10, label="Chunk Overlap") # FIXME - Add checkbox for 'Enable Summarization upon ingestion' for API summarization of chunks diff --git a/App_Function_Libraries/Gradio_UI/Mind_Map_tab.py b/App_Function_Libraries/Gradio_UI/Mind_Map_tab.py new file mode 100644 index 0000000000000000000000000000000000000000..a635f9e6309278b240cb9f083c143ba4fd2c8585 --- /dev/null +++ b/App_Function_Libraries/Gradio_UI/Mind_Map_tab.py @@ -0,0 +1,128 @@ +# Mind_Map_tab.py +# Description: File contains functions for generation of PlantUML mindmaps for the gradio tab +# +# Imports +import re +# +# External Libraries +import gradio as gr +# +###################################################################################################################### +# +# Functions: + +def parse_plantuml_mindmap(plantuml_text: str) -> dict: + """Parse PlantUML mindmap syntax into a nested dictionary structure""" + lines = [line.strip() for line in plantuml_text.split('\n') + if line.strip() and not line.strip().startswith('@')] + + root = None + nodes = [] + stack = [] + + for line in lines: + level_match = re.match(r'^([+\-*]+|\*+)', line) + if not level_match: + continue + level = len(level_match.group(0)) + text = re.sub(r'^([+\-*]+|\*+)\s*', '', line).strip('[]').strip('()') + node = {'text': text, 'children': []} + + while stack and stack[-1][0] >= level: + stack.pop() + + if stack: + stack[-1][1]['children'].append(node) + else: + root = node + + stack.append((level, node)) + + return root + +def create_mindmap_html(plantuml_text: str) -> str: + """Convert PlantUML mindmap to HTML visualization with collapsible nodes using CSS only""" + # Parse the mindmap text into a 
nested structure + root_node = parse_plantuml_mindmap(plantuml_text) + if not root_node: + return "

No valid mindmap content provided.

" + + html = "" + + colors = ['#e6f3ff', '#f0f7ff', '#f5f5f5', '#fff0f0', '#f0fff0'] + + def create_node_html(node, level): + bg_color = colors[(level - 1) % len(colors)] + if node['children']: + children_html = ''.join(create_node_html(child, level + 1) for child in node['children']) + return f""" +
+ {node['text']} + {children_html} +
+ """ + else: + return f""" +
+ {node['text']} +
+ """ + + html += create_node_html(root_node, level=1) + return html + +# Create Gradio interface +def create_mindmap_tab(): + with gr.TabItem("PlantUML Mindmap"): + gr.Markdown("# Collapsible PlantUML Mindmap Visualizer") + gr.Markdown("Convert PlantUML mindmap syntax to a visual mindmap with collapsible nodes.") + plantuml_input = gr.Textbox( + lines=15, + label="Enter PlantUML mindmap", + placeholder="""@startmindmap + * Project Planning + ** Requirements + *** Functional Requirements + **** User Interface + **** Backend Services + *** Technical Requirements + **** Performance + **** Security + ** Timeline + *** Phase 1 + *** Phase 2 + ** Resources + *** Team + *** Budget + @endmindmap""" + ) + submit_btn = gr.Button("Generate Mindmap") + mindmap_output = gr.HTML(label="Mindmap Output") + submit_btn.click( + fn=create_mindmap_html, + inputs=plantuml_input, + outputs=mindmap_output + ) + +# +# End of Mind_Map_tab.py +###################################################################################################################### diff --git a/App_Function_Libraries/Gradio_UI/PDF_ingestion_tab.py b/App_Function_Libraries/Gradio_UI/PDF_ingestion_tab.py index cd287615660772dbdcc0f39cc59d6c6b0265ed63..697e9d676a82285b8f9ef5a46d74e7bcec94ce6b 100644 --- a/App_Function_Libraries/Gradio_UI/PDF_ingestion_tab.py +++ b/App_Function_Libraries/Gradio_UI/PDF_ingestion_tab.py @@ -8,9 +8,12 @@ import tempfile # # External Imports import gradio as gr +import pymupdf4llm +from docling.document_converter import DocumentConverter + # # Local Imports -from App_Function_Libraries.DB.DB_Manager import load_preset_prompts +from App_Function_Libraries.DB.DB_Manager import list_prompts from App_Function_Libraries.Gradio_UI.Chat_ui import update_user_prompt from App_Function_Libraries.PDF.PDF_Ingestion_Lib import extract_metadata_from_pdf, extract_text_and_format_from_pdf, \ process_and_cleanup_pdf @@ -22,92 +25,258 @@ from App_Function_Libraries.PDF.PDF_Ingestion_Lib import 
extract_metadata_from_p def create_pdf_ingestion_tab(): with gr.TabItem("PDF Ingestion", visible=True): - # TODO - Add functionality to extract metadata from pdf as part of conversion process in marker gr.Markdown("# Ingest PDF Files and Extract Metadata") with gr.Row(): with gr.Column(): - pdf_file_input = gr.File(label="Uploaded PDF File", file_types=[".pdf"], visible=True) - pdf_upload_button = gr.UploadButton("Click to Upload PDF", file_types=[".pdf"]) - pdf_title_input = gr.Textbox(label="Title (Optional)") - pdf_author_input = gr.Textbox(label="Author (Optional)") - pdf_keywords_input = gr.Textbox(label="Keywords (Optional, comma-separated)") - with gr.Row(): - custom_prompt_checkbox = gr.Checkbox(label="Use a Custom Prompt", - value=False, - visible=True) - preset_prompt_checkbox = gr.Checkbox(label="Use a pre-set Prompt", - value=False, - visible=True) - with gr.Row(): - preset_prompt = gr.Dropdown(label="Select Preset Prompt", - choices=load_preset_prompts(), - visible=False) - with gr.Row(): - custom_prompt_input = gr.Textbox(label="Custom Prompt", - placeholder="Enter custom prompt here", - lines=3, - visible=False) - with gr.Row(): - system_prompt_input = gr.Textbox(label="System Prompt", - value=""" -You are a bulleted notes specialist. -[INST]```When creating comprehensive bulleted notes, you should follow these guidelines: Use multiple headings based on the referenced topics, not categories like quotes or terms. Headings should be surrounded by bold formatting and not be listed as bullet points themselves. Leave no space between headings and their corresponding list items underneath. Important terms within the content should be emphasized by setting them in bold font. Any text that ends with a colon should also be bolded. Before submitting your response, review the instructions, and make any corrections necessary to adhered to the specified format. 
Do not reference these instructions within the notes.``` \nBased on the content between backticks create comprehensive bulleted notes.[/INST] -**Bulleted Note Creation Guidelines** - -**Headings**: -- Based on referenced topics, not categories like quotes or terms -- Surrounded by **bold** formatting -- Not listed as bullet points -- No space between headings and list items underneath - -**Emphasis**: -- **Important terms** set in bold font -- **Text ending in a colon**: also bolded - -**Review**: -- Ensure adherence to specified format -- Do not reference these instructions in your response.[INST] {{ .Prompt }} [/INST]""", - lines=3, - visible=False) - - custom_prompt_checkbox.change( - fn=lambda x: (gr.update(visible=x), gr.update(visible=x)), - inputs=[custom_prompt_checkbox], - outputs=[custom_prompt_input, system_prompt_input] + # Changed to support multiple files + pdf_file_input = gr.File( + label="Uploaded PDF Files", + file_types=[".pdf"], + visible=True, + file_count="multiple" ) - preset_prompt_checkbox.change( - fn=lambda x: gr.update(visible=x), - inputs=[preset_prompt_checkbox], - outputs=[preset_prompt] + pdf_upload_button = gr.UploadButton( + "Click to Upload PDFs", + file_types=[".pdf"], + file_count="multiple" ) - - def update_prompts(preset_name): - prompts = update_user_prompt(preset_name) - return ( - gr.update(value=prompts["user_prompt"], visible=True), - gr.update(value=prompts["system_prompt"], visible=True) - ) - - preset_prompt.change( - update_prompts, - inputs=preset_prompt, - outputs=[custom_prompt_input, system_prompt_input] + parser_selection = gr.Radio( + choices=["pymupdf", "pymupdf4llm", "docling"], + label="Select Parser", + value="pymupdf" # default value ) + # Common metadata for all files + pdf_keywords_input = gr.Textbox(label="Keywords (Optional, comma-separated)") +# with gr.Row(): +# custom_prompt_checkbox = gr.Checkbox( +# label="Use a Custom Prompt", +# value=False, +# visible=True +# ) +# preset_prompt_checkbox = 
gr.Checkbox( +# label="Use a pre-set Prompt", +# value=False, +# visible=True +# ) +# # Initialize state variables for pagination +# current_page_state = gr.State(value=1) +# total_pages_state = gr.State(value=1) +# with gr.Row(): +# # Add pagination controls +# preset_prompt = gr.Dropdown( +# label="Select Preset Prompt", +# choices=[], +# visible=False +# ) +# prev_page_button = gr.Button("Previous Page", visible=False) +# page_display = gr.Markdown("Page 1 of X", visible=False) +# next_page_button = gr.Button("Next Page", visible=False) +# with gr.Row(): +# custom_prompt_input = gr.Textbox( +# label="Custom Prompt", +# placeholder="Enter custom prompt here", +# lines=3, +# visible=False +# ) +# with gr.Row(): +# system_prompt_input = gr.Textbox( +# label="System Prompt", +# value=""" +# You are a bulleted notes specialist. +# [INST]```When creating comprehensive bulleted notes, you should follow these guidelines: Use multiple headings based on the referenced topics, not categories like quotes or terms. Headings should be surrounded by bold formatting and not be listed as bullet points themselves. Leave no space between headings and their corresponding list items underneath. Important terms within the content should be emphasized by setting them in bold font. Any text that ends with a colon should also be bolded. Before submitting your response, review the instructions, and make any corrections necessary to adhered to the specified format. 
Do not reference these instructions within the notes.``` \nBased on the content between backticks create comprehensive bulleted notes.[/INST] +# **Bulleted Note Creation Guidelines** +# +# **Headings**: +# - Based on referenced topics, not categories like quotes or terms +# - Surrounded by **bold** formatting +# - Not listed as bullet points +# - No space between headings and list items underneath +# +# **Emphasis**: +# - **Important terms** set in bold font +# - **Text ending in a colon**: also bolded +# +# **Review**: +# - Ensure adherence to specified format +# - Do not reference these instructions in your response.[INST] {{ .Prompt }} [/INST]""", +# lines=3, +# visible=False +# ) +# +# custom_prompt_checkbox.change( +# fn=lambda x: (gr.update(visible=x), gr.update(visible=x)), +# inputs=[custom_prompt_checkbox], +# outputs=[custom_prompt_input, system_prompt_input] +# ) +# +# def on_preset_prompt_checkbox_change(is_checked): +# if is_checked: +# prompts, total_pages, current_page = list_prompts(page=1, per_page=10) +# page_display_text = f"Page {current_page} of {total_pages}" +# return ( +# gr.update(visible=True, interactive=True, choices=prompts), # preset_prompt +# gr.update(visible=True), # prev_page_button +# gr.update(visible=True), # next_page_button +# gr.update(value=page_display_text, visible=True), # page_display +# current_page, # current_page_state +# total_pages # total_pages_state +# ) +# else: +# return ( +# gr.update(visible=False, interactive=False), # preset_prompt +# gr.update(visible=False), # prev_page_button +# gr.update(visible=False), # next_page_button +# gr.update(visible=False), # page_display +# 1, # current_page_state +# 1 # total_pages_state +# ) +# +# preset_prompt_checkbox.change( +# fn=on_preset_prompt_checkbox_change, +# inputs=[preset_prompt_checkbox], +# outputs=[preset_prompt, prev_page_button, next_page_button, page_display, current_page_state, total_pages_state] +# ) +# +# def on_prev_page_click(current_page, 
total_pages): +# new_page = max(current_page - 1, 1) +# prompts, total_pages, current_page = list_prompts(page=new_page, per_page=10) +# page_display_text = f"Page {current_page} of {total_pages}" +# return gr.update(choices=prompts), gr.update(value=page_display_text), current_page +# +# prev_page_button.click( +# fn=on_prev_page_click, +# inputs=[current_page_state, total_pages_state], +# outputs=[preset_prompt, page_display, current_page_state] +# ) +# +# def on_next_page_click(current_page, total_pages): +# new_page = min(current_page + 1, total_pages) +# prompts, total_pages, current_page = list_prompts(page=new_page, per_page=10) +# page_display_text = f"Page {current_page} of {total_pages}" +# return gr.update(choices=prompts), gr.update(value=page_display_text), current_page +# +# next_page_button.click( +# fn=on_next_page_click, +# inputs=[current_page_state, total_pages_state], +# outputs=[preset_prompt, page_display, current_page_state] +# ) +# +# def update_prompts(preset_name): +# prompts = update_user_prompt(preset_name) +# return ( +# gr.update(value=prompts["user_prompt"], visible=True), +# gr.update(value=prompts["system_prompt"], visible=True) +# ) +# +# preset_prompt.change( +# update_prompts, +# inputs=preset_prompt, +# outputs=[custom_prompt_input, system_prompt_input] +# ) - pdf_ingest_button = gr.Button("Ingest PDF") + pdf_ingest_button = gr.Button("Ingest PDFs") + + # Update the upload button handler for multiple files + pdf_upload_button.upload( + fn=lambda files: files, + inputs=pdf_upload_button, + outputs=pdf_file_input + ) - pdf_upload_button.upload(fn=lambda file: file, inputs=pdf_upload_button, outputs=pdf_file_input) with gr.Column(): - pdf_result_output = gr.Textbox(label="Result") + pdf_result_output = gr.DataFrame( + headers=["Filename", "Status", "Message"], + label="Processing Results" + ) + + # Define a new function to handle multiple PDFs + def process_multiple_pdfs(pdf_files, keywords, custom_prompt_checkbox_value, 
custom_prompt_text, system_prompt_text): + results = [] + if pdf_files is None: + return [["No files", "Error", "No files uploaded"]] + + for pdf_file in pdf_files: + try: + # Extract metadata from PDF + metadata = extract_metadata_from_pdf(pdf_file.name) + # Use custom or system prompt if checkbox is checked + if custom_prompt_checkbox_value: + prompt = custom_prompt_text + system_prompt = system_prompt_text + else: + prompt = None + system_prompt = None + + # Process the PDF with prompts + result = process_and_cleanup_pdf( + pdf_file, + metadata.get('title', os.path.splitext(os.path.basename(pdf_file.name))[0]), + metadata.get('author', 'Unknown'), + keywords, + #prompt=prompt, + #system_prompt=system_prompt + ) + + results.append([ + pdf_file.name, + "Success" if "successfully" in result else "Error", + result + ]) + except Exception as e: + results.append([ + pdf_file.name, + "Error", + str(e) + ]) + + return results + + # Update the ingest button click handler pdf_ingest_button.click( - fn=process_and_cleanup_pdf, - inputs=[pdf_file_input, pdf_title_input, pdf_author_input, pdf_keywords_input], + fn=process_multiple_pdfs, + inputs=[ + pdf_file_input, + pdf_keywords_input, + parser_selection, + #custom_prompt_checkbox, + #custom_prompt_input, + #system_prompt_input + ], outputs=pdf_result_output ) -def test_pdf_ingestion(pdf_file): +def test_pymupdf4llm_pdf_ingestion(pdf_file): + if pdf_file is None: + return "No file uploaded", "" + + try: + # Create a temporary directory + with tempfile.TemporaryDirectory() as temp_dir: + # Create a path for the temporary PDF file + temp_path = os.path.join(temp_dir, "temp.pdf") + + # Copy the contents of the uploaded file to the temporary file + shutil.copy(pdf_file.name, temp_path) + + # Extract text and convert to Markdown + markdown_text = pymupdf4llm.to_markdown(temp_path) + + # Extract metadata from PDF + metadata = extract_metadata_from_pdf(temp_path) + + # Use metadata for title and author if not provided + title = 
metadata.get('title', os.path.splitext(os.path.basename(pdf_file.name))[0]) + author = metadata.get('author', 'Unknown') + + result = f"PDF '{title}' by {author} processed successfully by pymupdf4llm." + return result, markdown_text + except Exception as e: + return f"Error ingesting PDF: {str(e)}", "" + + +def test_pymupdf_pdf_ingestion(pdf_file): if pdf_file is None: return "No file uploaded", "" @@ -130,7 +299,37 @@ def test_pdf_ingestion(pdf_file): title = metadata.get('title', os.path.splitext(os.path.basename(pdf_file.name))[0]) author = metadata.get('author', 'Unknown') - result = f"PDF '{title}' by {author} processed successfully." + result = f"PDF '{title}' by {author} processed successfully by pymupdf." + return result, markdown_text + except Exception as e: + return f"Error ingesting PDF: {str(e)}", "" + + +def test_docling_pdf_ingestion(pdf_file): + if pdf_file is None: + return "No file uploaded", "" + + try: + # Create a temporary directory + with tempfile.TemporaryDirectory() as temp_dir: + # Create a path for the temporary PDF file + temp_path = os.path.join(temp_dir, "temp.pdf") + + # Copy the contents of the uploaded file to the temporary file + shutil.copy(pdf_file.name, temp_path) + + # Extract text and convert to Markdown + converter = DocumentConverter() + parsed_pdf = converter.convert(temp_path) + markdown_text = parsed_pdf.document.export_to_markdown() + # Extract metadata from PDF + metadata = extract_metadata_from_pdf(temp_path) + + # Use metadata for title and author if not provided + title = metadata.get('title', os.path.splitext(os.path.basename(pdf_file.name))[0]) + author = metadata.get('author', 'Unknown') + + result = f"PDF '{title}' by {author} processed successfully by pymupdf." 
return result, markdown_text except Exception as e: return f"Error ingesting PDF: {str(e)}", "" @@ -140,12 +339,24 @@ def create_pdf_ingestion_test_tab(): with gr.Row(): with gr.Column(): pdf_file_input = gr.File(label="Upload PDF for testing") - test_button = gr.Button("Test PDF Ingestion") + test_button = gr.Button("Test pymupdf PDF Ingestion") + test_button_2 = gr.Button("Test pymupdf4llm PDF Ingestion") + test_button_3 = gr.Button("Test Docling PDF Ingestion") with gr.Column(): test_output = gr.Textbox(label="Test Result") pdf_content_output = gr.Textbox(label="PDF Content", lines=200) test_button.click( - fn=test_pdf_ingestion, + fn=test_pymupdf_pdf_ingestion, + inputs=[pdf_file_input], + outputs=[test_output, pdf_content_output] + ) + test_button_2.click( + fn=test_pymupdf4llm_pdf_ingestion, + inputs=[pdf_file_input], + outputs=[test_output, pdf_content_output] + ) + test_button_3.click( + fn=test_docling_pdf_ingestion, inputs=[pdf_file_input], outputs=[test_output, pdf_content_output] ) diff --git a/App_Function_Libraries/Gradio_UI/Plaintext_tab_import.py b/App_Function_Libraries/Gradio_UI/Plaintext_tab_import.py index 9491ed6ad67f853488aa3bd12b42920358fe253e..e3103682133a88cadd928faca42ae00528db22b7 100644 --- a/App_Function_Libraries/Gradio_UI/Plaintext_tab_import.py +++ b/App_Function_Libraries/Gradio_UI/Plaintext_tab_import.py @@ -6,6 +6,7 @@ ####################################################################################################################### # # Import necessary libraries +import logging import os import tempfile import zipfile @@ -16,101 +17,104 @@ from docx2txt import docx2txt from pypandoc import convert_file # # Import Local libraries -from App_Function_Libraries.Gradio_UI.Import_Functionality import import_data +from App_Function_Libraries.Plaintext.Plaintext_Files import import_file_handler +from App_Function_Libraries.Utils.Utils import default_api_endpoint, global_api_endpoints, format_api_name # 
####################################################################################################################### # # Functions: def create_plain_text_import_tab(): + try: + default_value = None + if default_api_endpoint: + if default_api_endpoint in global_api_endpoints: + default_value = format_api_name(default_api_endpoint) + else: + logging.warning(f"Default API endpoint '{default_api_endpoint}' not found in global_api_endpoints") + except Exception as e: + logging.error(f"Error setting default API endpoint: {str(e)}") + default_value = None + with gr.TabItem("Import Plain text & .docx Files", visible=True): with gr.Row(): with gr.Column(): - gr.Markdown("# Import Markdown(`.md`)/Text(`.txt`)/rtf & `.docx` Files") - gr.Markdown("Upload a single file or a zip file containing multiple files") - import_file = gr.File(label="Upload file for import", file_types=[".md", ".txt", ".rtf", ".docx", ".zip"]) - title_input = gr.Textbox(label="Title", placeholder="Enter the title of the content (for single files)") - author_input = gr.Textbox(label="Author", placeholder="Enter the author's name (for single files)") - keywords_input = gr.Textbox(label="Keywords", placeholder="Enter keywords, comma-separated") - system_prompt_input = gr.Textbox(label="System Prompt (for Summarization)", lines=3, - value="""You are a bulleted notes specialist. [INST]```When creating comprehensive bulleted notes, you should follow these guidelines: Use multiple headings based on the referenced topics, not categories like quotes or terms. Headings should be surrounded by bold formatting and not be listed as bullet points themselves. Leave no space between headings and their corresponding list items underneath. Important terms within the content should be emphasized by setting them in bold font. Any text that ends with a colon should also be bolded. Before submitting your response, review the instructions, and make any corrections necessary to adhered to the specified format. 
Do not reference these instructions within the notes.``` \nBased on the content between backticks create comprehensive bulleted notes.[/INST] - **Bulleted Note Creation Guidelines** - - **Headings**: - - Based on referenced topics, not categories like quotes or terms - - Surrounded by **bold** formatting - - Not listed as bullet points - - No space between headings and list items underneath + gr.Markdown("# Import `.md`/`.txt`/`.rtf`/`.docx` Files & `.zip` collections of them.") + gr.Markdown("Upload multiple files or a zip file containing multiple files") - **Emphasis**: - - **Important terms** set in bold font - - **Text ending in a colon**: also bolded + # Updated to support multiple files + import_files = gr.File( + label="Upload files for import", + file_count="multiple", + file_types=[".md", ".txt", ".rtf", ".docx", ".zip"] + ) - **Review**: - - Ensure adherence to specified format - - Do not reference these instructions in your response.[INST] {{ .Prompt }} [/INST]""", - ) - custom_prompt_input = gr.Textbox(label="Custom User Prompt", placeholder="Enter a custom user prompt for summarization (optional)") + # Optional metadata override fields + author_input = gr.Textbox( + label="Author Override (optional)", + placeholder="Enter author name to apply to all files" + ) + keywords_input = gr.Textbox( + label="Keywords", + placeholder="Enter keywords, comma-separated - will be applied to all files" + ) + system_prompt_input = gr.Textbox( + label="System Prompt (for Summarization)", + lines=3, + value=""" + You are a bulleted notes specialist. [INST]```When creating comprehensive bulleted notes, you should follow these guidelines: Use multiple headings based on the referenced topics, not categories like quotes or terms. Headings should be surrounded by bold formatting and not be listed as bullet points themselves. Leave no space between headings and their corresponding list items underneath. 
Important terms within the content should be emphasized by setting them in bold font. Any text that ends with a colon should also be bolded. Before submitting your response, review the instructions, and make any corrections necessary to adhered to the specified format. Do not reference these instructions within the notes.``` \nBased on the content between backticks create comprehensive bulleted notes.[/INST] + **Bulleted Note Creation Guidelines** + + **Headings**: + - Based on referenced topics, not categories like quotes or terms + - Surrounded by **bold** formatting + - Not listed as bullet points + - No space between headings and list items underneath + + **Emphasis**: + - **Important terms** set in bold font + - **Text ending in a colon**: also bolded + + **Review**: + - Ensure adherence to specified format + - Do not reference these instructions in your response.[INST] + """ + ) + custom_prompt_input = gr.Textbox( + label="Custom User Prompt", + placeholder="Enter a custom user prompt for summarization (optional)" + ) auto_summarize_checkbox = gr.Checkbox(label="Auto-summarize", value=False) + + # API configuration api_name_input = gr.Dropdown( - choices=[None, "Local-LLM", "OpenAI", "Anthropic", "Cohere", "Groq", "DeepSeek", "Mistral", - "OpenRouter", "Llama.cpp", "Kobold", "Ooba", "Tabbyapi", "VLLM", "ollama", "HuggingFace"], - label="API for Auto-summarization" + choices=["None"] + [format_api_name(api) for api in global_api_endpoints], + value=default_value, + label="API for Summarization/Analysis (Optional)" ) api_key_input = gr.Textbox(label="API Key", type="password") import_button = gr.Button("Import File(s)") - with gr.Column(): - import_output = gr.Textbox(label="Import Status") - - def import_plain_text_file(file_path, title, author, keywords, system_prompt, user_prompt, auto_summarize, api_name, api_key): - try: - # Determine the file type and convert if necessary - file_extension = os.path.splitext(file_path)[1].lower() - if file_extension == 
'.rtf': - with tempfile.NamedTemporaryFile(suffix='.md', delete=False) as temp_file: - convert_file(file_path, 'md', outputfile=temp_file.name) - file_path = temp_file.name - elif file_extension == '.docx': - content = docx2txt.process(file_path) - else: - with open(file_path, 'r', encoding='utf-8') as file: - content = file.read() - - # Process the content - return import_data(content, title, author, keywords, system_prompt, - user_prompt, auto_summarize, api_name, api_key) - except Exception as e: - return f"Error processing file: {str(e)}" - - def process_plain_text_zip_file(zip_file, title, author, keywords, system_prompt, user_prompt, auto_summarize, api_name, api_key): - results = [] - with tempfile.TemporaryDirectory() as temp_dir: - with zipfile.ZipFile(zip_file.name, 'r') as zip_ref: - zip_ref.extractall(temp_dir) - - for filename in os.listdir(temp_dir): - if filename.lower().endswith(('.md', '.txt', '.rtf', '.docx')): - file_path = os.path.join(temp_dir, filename) - result = import_plain_text_file(file_path, title, author, keywords, system_prompt, - user_prompt, auto_summarize, api_name, api_key) - results.append(f"File: {filename} - {result}") - - return "\n".join(results) - - def import_file_handler(file, title, author, keywords, system_prompt, user_prompt, auto_summarize, api_name, api_key): - if file.name.lower().endswith(('.md', '.txt', '.rtf', '.docx')): - return import_plain_text_file(file.name, title, author, keywords, system_prompt, user_prompt, auto_summarize, api_name, api_key) - elif file.name.lower().endswith('.zip'): - return process_plain_text_zip_file(file, title, author, keywords, system_prompt, user_prompt, auto_summarize, api_name, api_key) - else: - return "Unsupported file type. Please upload a .md, .txt, .rtf, .docx file or a .zip file containing these file types." 
+ with gr.Column(): + import_output = gr.Textbox(label="Import Status", lines=10) import_button.click( fn=import_file_handler, - inputs=[import_file, title_input, author_input, keywords_input, system_prompt_input, - custom_prompt_input, auto_summarize_checkbox, api_name_input, api_key_input], + inputs=[ + import_files, + author_input, + keywords_input, + system_prompt_input, + custom_prompt_input, + auto_summarize_checkbox, + api_name_input, + api_key_input + ], outputs=import_output ) - return import_file, title_input, author_input, keywords_input, system_prompt_input, custom_prompt_input, auto_summarize_checkbox, api_name_input, api_key_input, import_button, import_output \ No newline at end of file + return import_files, author_input, keywords_input, system_prompt_input, custom_prompt_input, auto_summarize_checkbox, api_name_input, api_key_input, import_button, import_output + +# +# End of Plain_text_import.py +####################################################################################################################### diff --git a/App_Function_Libraries/Gradio_UI/Podcast_tab.py b/App_Function_Libraries/Gradio_UI/Podcast_tab.py index d6187e17937bc602928c9d44277925e2ab3cfeb2..eeeabd3c905fc47a56b3beabf7608b61f761577e 100644 --- a/App_Function_Libraries/Gradio_UI/Podcast_tab.py +++ b/App_Function_Libraries/Gradio_UI/Podcast_tab.py @@ -2,23 +2,38 @@ # Description: Gradio UI for ingesting podcasts into the database # # Imports +import logging # # External Imports import gradio as gr # # Local Imports from App_Function_Libraries.Audio.Audio_Files import process_podcast -from App_Function_Libraries.DB.DB_Manager import load_preset_prompts +from App_Function_Libraries.DB.DB_Manager import list_prompts from App_Function_Libraries.Gradio_UI.Gradio_Shared import whisper_models, update_user_prompt +from App_Function_Libraries.Utils.Utils import default_api_endpoint, global_api_endpoints, format_api_name # 
######################################################################################################################## # # Functions: - def create_podcast_tab(): + try: + default_value = None + if default_api_endpoint: + if default_api_endpoint in global_api_endpoints: + default_value = format_api_name(default_api_endpoint) + else: + logging.warning(f"Default API endpoint '{default_api_endpoint}' not found in global_api_endpoints") + except Exception as e: + logging.error(f"Error setting default API endpoint: {str(e)}") + default_value = None with gr.TabItem("Podcast", visible=True): gr.Markdown("# Podcast Transcription and Ingestion", visible=True) + # Initialize state variables for pagination + current_page_state = gr.State(value=1) + total_pages_state = gr.State(value=1) + with gr.Row(): with gr.Column(): podcast_url_input = gr.Textbox(label="Podcast URL", placeholder="Enter the podcast URL here") @@ -35,54 +50,130 @@ def create_podcast_tab(): keep_timestamps_input = gr.Checkbox(label="Keep Timestamps", value=True) with gr.Row(): - podcast_custom_prompt_checkbox = gr.Checkbox(label="Use a Custom Prompt", - value=False, - visible=True) - preset_prompt_checkbox = gr.Checkbox(label="Use a pre-set Prompt", - value=False, - visible=True) + podcast_custom_prompt_checkbox = gr.Checkbox( + label="Use a Custom Prompt", + value=False, + visible=True + ) + preset_prompt_checkbox = gr.Checkbox( + label="Use a pre-set Prompt", + value=False, + visible=True + ) + with gr.Row(): - preset_prompt = gr.Dropdown(label="Select Preset Prompt", - choices=load_preset_prompts(), - visible=False) + # Add pagination controls + preset_prompt = gr.Dropdown( + label="Select Preset Prompt", + choices=[], + visible=False + ) with gr.Row(): - podcast_custom_prompt_input = gr.Textbox(label="Custom Prompt", - placeholder="Enter custom prompt here", - lines=3, - visible=False) + prev_page_button = gr.Button("Previous Page", visible=False) + page_display = gr.Markdown("Page 1 of X", 
visible=False) + next_page_button = gr.Button("Next Page", visible=False) + with gr.Row(): - system_prompt_input = gr.Textbox(label="System Prompt", - value="""You are a bulleted notes specialist. [INST]```When creating comprehensive bulleted notes, you should follow these guidelines: Use multiple headings based on the referenced topics, not categories like quotes or terms. Headings should be surrounded by bold formatting and not be listed as bullet points themselves. Leave no space between headings and their corresponding list items underneath. Important terms within the content should be emphasized by setting them in bold font. Any text that ends with a colon should also be bolded. Before submitting your response, review the instructions, and make any corrections necessary to adhered to the specified format. Do not reference these instructions within the notes.``` \nBased on the content between backticks create comprehensive bulleted notes.[/INST] -**Bulleted Note Creation Guidelines** - -**Headings**: -- Based on referenced topics, not categories like quotes or terms -- Surrounded by **bold** formatting -- Not listed as bullet points -- No space between headings and list items underneath - -**Emphasis**: -- **Important terms** set in bold font -- **Text ending in a colon**: also bolded - -**Review**: -- Ensure adherence to specified format -- Do not reference these instructions in your response.[INST] {{ .Prompt }} [/INST] -""", - lines=3, - visible=False) + podcast_custom_prompt_input = gr.Textbox( + label="Custom Prompt", + placeholder="Enter custom prompt here", + lines=10, + visible=False + ) + with gr.Row(): + system_prompt_input = gr.Textbox( + label="System Prompt", + value="""You are a bulleted notes specialist. [INST]```When creating comprehensive bulleted notes, you should follow these guidelines: Use multiple headings based on the referenced topics, not categories like quotes or terms. 
Headings should be surrounded by bold formatting and not be listed as bullet points themselves. Leave no space between headings and their corresponding list items underneath. Important terms within the content should be emphasized by setting them in bold font. Any text that ends with a colon should also be bolded. Before submitting your response, review the instructions, and make any corrections necessary to adhere to the specified format. Do not reference these instructions within the notes.``` \nBased on the content between backticks create comprehensive bulleted notes.[/INST] + **Bulleted Note Creation Guidelines** + + **Headings**: + - Based on referenced topics, not categories like quotes or terms + - Surrounded by **bold** formatting + - Not listed as bullet points + - No space between headings and list items underneath + + **Emphasis**: + - **Important terms** set in bold font + - **Text ending in a colon**: also bolded + + **Review**: + - Ensure adherence to specified format + - Do not reference these instructions in your response.[INST] {{ .Prompt }} [/INST] + """, + lines=10, + visible=False + ) + # Handle custom prompt checkbox change podcast_custom_prompt_checkbox.change( fn=lambda x: (gr.update(visible=x), gr.update(visible=x)), inputs=[podcast_custom_prompt_checkbox], outputs=[podcast_custom_prompt_input, system_prompt_input] ) + + # Handle preset prompt checkbox change + def on_preset_prompt_checkbox_change(is_checked): + if is_checked: + prompts, total_pages, current_page = list_prompts(page=1, per_page=20) + page_display_text = f"Page {current_page} of {total_pages}" + return ( + gr.update(visible=True, interactive=True, choices=prompts), # preset_prompt + gr.update(visible=True), # prev_page_button + gr.update(visible=True), # next_page_button + gr.update(value=page_display_text, visible=True), # page_display + current_page, # current_page_state + total_pages # total_pages_state + ) + else: + return ( + gr.update(visible=False, interactive=False), 
# preset_prompt + gr.update(visible=False), # prev_page_button + gr.update(visible=False), # next_page_button + gr.update(visible=False), # page_display + 1, # current_page_state + 1 # total_pages_state + ) + preset_prompt_checkbox.change( - fn=lambda x: gr.update(visible=x), + fn=on_preset_prompt_checkbox_change, inputs=[preset_prompt_checkbox], - outputs=[preset_prompt] + outputs=[preset_prompt, prev_page_button, next_page_button, page_display, current_page_state, total_pages_state] + ) + + # Pagination button functions + def on_prev_page_click(current_page, total_pages): + new_page = max(current_page - 1, 1) + prompts, total_pages, current_page = list_prompts(page=new_page, per_page=20) + page_display_text = f"Page {current_page} of {total_pages}" + return ( + gr.update(choices=prompts), + gr.update(value=page_display_text), + current_page + ) + + prev_page_button.click( + fn=on_prev_page_click, + inputs=[current_page_state, total_pages_state], + outputs=[preset_prompt, page_display, current_page_state] + ) + + def on_next_page_click(current_page, total_pages): + new_page = min(current_page + 1, total_pages) + prompts, total_pages, current_page = list_prompts(page=new_page, per_page=20) + page_display_text = f"Page {current_page} of {total_pages}" + return ( + gr.update(choices=prompts), + gr.update(value=page_display_text), + current_page + ) + + next_page_button.click( + fn=on_next_page_click, + inputs=[current_page_state, total_pages_state], + outputs=[preset_prompt, page_display, current_page_state] ) + # Update prompts when a preset is selected def update_prompts(preset_name): prompts = update_user_prompt(preset_name) return ( @@ -91,16 +182,16 @@ def create_podcast_tab(): ) preset_prompt.change( - update_prompts, - inputs=preset_prompt, + fn=update_prompts, + inputs=[preset_prompt], outputs=[podcast_custom_prompt_input, system_prompt_input] ) + # Refactored API selection dropdown podcast_api_name_input = gr.Dropdown( - choices=[None, "Local-LLM", "OpenAI", 
"Anthropic", "Cohere", "Groq", "DeepSeek", "Mistral", "OpenRouter", "Llama.cpp", - "Kobold", "Ooba", "Tabbyapi", "VLLM","ollama", "HuggingFace", "Custom-OpenAI-API"], - value=None, - label="API Name for Summarization (Optional)" + choices=["None"] + [format_api_name(api) for api in global_api_endpoints], + value=default_value, + label="API for Summarization/Analysis (Optional)" ) podcast_api_key_input = gr.Textbox(label="API Key (if required)", type="password") podcast_whisper_model_input = gr.Dropdown(choices=whisper_models, value="medium", label="Whisper Model") @@ -151,13 +242,37 @@ def create_podcast_tab(): podcast_process_button.click( fn=process_podcast, - inputs=[podcast_url_input, podcast_title_input, podcast_author_input, - podcast_keywords_input, podcast_custom_prompt_input, podcast_api_name_input, - podcast_api_key_input, podcast_whisper_model_input, keep_original_input, - enable_diarization_input, use_cookies_input, cookies_input, - chunk_method, max_chunk_size, chunk_overlap, use_adaptive_chunking, - use_multi_level_chunking, chunk_language, keep_timestamps_input], - outputs=[podcast_progress_output, podcast_transcription_output, podcast_summary_output, - podcast_title_input, podcast_author_input, podcast_keywords_input, podcast_error_output, - download_transcription, download_summary] + inputs=[ + podcast_url_input, + podcast_title_input, + podcast_author_input, + podcast_keywords_input, + podcast_custom_prompt_input, + podcast_api_name_input, + podcast_api_key_input, + podcast_whisper_model_input, + keep_original_input, + enable_diarization_input, + use_cookies_input, + cookies_input, + chunk_method, + max_chunk_size, + chunk_overlap, + use_adaptive_chunking, + use_multi_level_chunking, + chunk_language, + keep_timestamps_input, + system_prompt_input # Include system prompt input + ], + outputs=[ + podcast_progress_output, + podcast_transcription_output, + podcast_summary_output, + podcast_title_input, + podcast_author_input, + 
podcast_keywords_input, + podcast_error_output, + download_transcription, + download_summary + ] ) \ No newline at end of file diff --git a/App_Function_Libraries/Gradio_UI/Prompt_Suggestion_tab.py b/App_Function_Libraries/Gradio_UI/Prompt_Suggestion_tab.py index 861ba53d74ac89fea6167125b50090fcd72082ae..9c7f215619b5701e5e325881aab98c72061d8f96 100644 --- a/App_Function_Libraries/Gradio_UI/Prompt_Suggestion_tab.py +++ b/App_Function_Libraries/Gradio_UI/Prompt_Suggestion_tab.py @@ -1,11 +1,14 @@ # Description: Gradio UI for Creating and Testing new Prompts # # Imports +import logging + import gradio as gr -from App_Function_Libraries.Chat import chat -from App_Function_Libraries.DB.SQLite_DB import add_or_update_prompt +from App_Function_Libraries.Chat.Chat_Functions import chat +from App_Function_Libraries.DB.DB_Manager import add_or_update_prompt from App_Function_Libraries.Prompt_Engineering.Prompt_Engineering import generate_prompt, test_generated_prompt +from App_Function_Libraries.Utils.Utils import format_api_name, global_api_endpoints, default_api_endpoint # @@ -18,6 +21,16 @@ from App_Function_Libraries.Prompt_Engineering.Prompt_Engineering import generat # Gradio tab for prompt suggestion and testing def create_prompt_suggestion_tab(): + try: + default_value = None + if default_api_endpoint: + if default_api_endpoint in global_api_endpoints: + default_value = format_api_name(default_api_endpoint) + else: + logging.warning(f"Default API endpoint '{default_api_endpoint}' not found in global_api_endpoints") + except Exception as e: + logging.error(f"Error setting default API endpoint: {str(e)}") + default_value = None with gr.TabItem("Prompt Suggestion/Creation", visible=True): gr.Markdown("# Generate and Test AI Prompts with the Metaprompt Approach") @@ -30,11 +43,11 @@ def create_prompt_suggestion_tab(): placeholder="E.g., CUSTOMER_COMPLAINT, COMPANY_NAME") # API-related inputs + # Refactored API selection dropdown api_name_input = gr.Dropdown( - 
choices=["OpenAI", "Cohere", "Groq", "DeepSeek", "Mistral", "OpenRouter", "Llama.cpp", - "Kobold", "Ooba", "Tabbyapi", "VLLM", "ollama", "HuggingFace", "Custom-OpenAI-API"], - label="API Provider", - value="OpenAI" # Default selection + choices=["None"] + [format_api_name(api) for api in global_api_endpoints], + value=default_value, + label="API for Analysis (Optional)" ) api_key_input = gr.Textbox(label="API Key", placeholder="Enter your API key (if required)", diff --git a/App_Function_Libraries/Gradio_UI/Prompts_tab.py b/App_Function_Libraries/Gradio_UI/Prompts_tab.py new file mode 100644 index 0000000000000000000000000000000000000000..d45de87e2c052bcd30374347454925ccf5d7312a --- /dev/null +++ b/App_Function_Libraries/Gradio_UI/Prompts_tab.py @@ -0,0 +1,297 @@ +# Prompts_tab.py +# Description: This file contains the code for the prompts tab in the Gradio UI +# +# Imports +import html +import logging + +# +# External Imports +import gradio as gr +# +# Local Imports +from App_Function_Libraries.DB.DB_Manager import fetch_prompt_details, list_prompts +# +#################################################################################################### +# +# Functions: + +def create_prompt_view_tab(): + with gr.TabItem("View Prompt Database", visible=True): + gr.Markdown("# View Prompt Database Entries") + with gr.Row(): + with gr.Column(): + entries_per_page = gr.Dropdown(choices=[10, 20, 50, 100], label="Entries per Page", value=10) + page_number = gr.Number(value=1, label="Page Number", precision=0) + view_button = gr.Button("View Page") + previous_page_button = gr.Button("Previous Page", visible=True) + next_page_button = gr.Button("Next Page", visible=True) + pagination_info = gr.Textbox(label="Pagination Info", interactive=False) + prompt_selector = gr.Dropdown(label="Select Prompt to View", choices=[]) + with gr.Column(): + results_table = gr.HTML() + selected_prompt_display = gr.HTML() + + # Function to view database entries + def view_database(page, 
entries_per_page): + try: + # Use list_prompts to get prompts and total pages + prompts, total_pages, current_page = list_prompts(page=int(page), per_page=int(entries_per_page)) + + table_html = "" + table_html += "" + prompt_choices = [] + for prompt_name in prompts: + details = fetch_prompt_details(prompt_name) + if details: + title, author, _, _, _, _ = details + author = author or "Unknown" # Handle None author + table_html += f"" + prompt_choices.append(prompt_name) # Using prompt_name as value + table_html += "
TitleAuthor
{html.escape(title)}{html.escape(author)}
" + + # Get total prompts if possible + total_prompts = total_pages * int(entries_per_page) # This might overestimate if the last page is not full + + pagination = f"Page {current_page} of {total_pages} (Total prompts: {total_prompts})" + + return table_html, pagination, total_pages, prompt_choices + except Exception as e: + return f"

Error fetching prompts: {e}

", "Error", 0, [] + + # Function to update page content + def update_page(page, entries_per_page): + results, pagination, total_pages, prompt_choices = view_database(page, entries_per_page) + page = int(page) + next_disabled = page >= total_pages + prev_disabled = page <= 1 + return ( + results, + pagination, + page, + gr.update(visible=True, interactive=not prev_disabled), # previous_page_button + gr.update(visible=True, interactive=not next_disabled), # next_page_button + gr.update(choices=prompt_choices) + ) + + # Function to go to the next page + def go_to_next_page(current_page, entries_per_page): + next_page = int(current_page) + 1 + return update_page(next_page, entries_per_page) + + # Function to go to the previous page + def go_to_previous_page(current_page, entries_per_page): + previous_page = max(1, int(current_page) - 1) + return update_page(previous_page, entries_per_page) + + # Function to display selected prompt details + def display_selected_prompt(prompt_name): + details = fetch_prompt_details(prompt_name) + if details: + title, author, description, system_prompt, user_prompt, keywords = details + # Handle None values by converting them to empty strings + description = description or "" + system_prompt = system_prompt or "" + user_prompt = user_prompt or "" + author = author or "Unknown" + keywords = keywords or "" + + html_content = f""" +
+

{html.escape(title)}

by {html.escape(author)}

+

Description: {html.escape(description)}

+
+ System Prompt: +
{html.escape(system_prompt)}
+
+
+ User Prompt: +
{html.escape(user_prompt)}
+
+

Keywords: {html.escape(keywords)}

+
+ """ + return html_content + else: + return "

Prompt not found.

" + + # Event handlers + view_button.click( + fn=update_page, + inputs=[page_number, entries_per_page], + outputs=[results_table, pagination_info, page_number, previous_page_button, next_page_button, prompt_selector] + ) + + next_page_button.click( + fn=go_to_next_page, + inputs=[page_number, entries_per_page], + outputs=[results_table, pagination_info, page_number, previous_page_button, next_page_button, prompt_selector] + ) + + previous_page_button.click( + fn=go_to_previous_page, + inputs=[page_number, entries_per_page], + outputs=[results_table, pagination_info, page_number, previous_page_button, next_page_button, prompt_selector] + ) + + prompt_selector.change( + fn=display_selected_prompt, + inputs=[prompt_selector], + outputs=[selected_prompt_display] + ) + + + +def create_prompts_export_tab(): + """Creates a tab for exporting prompts database content with multiple format options""" + with gr.TabItem("Export Prompts", visible=True): + gr.Markdown("# Export Prompts Database Content") + + with gr.Row(): + with gr.Column(): + export_type = gr.Radio( + choices=["All Prompts", "Prompts by Keyword"], + label="Export Type", + value="All Prompts" + ) + + # Keyword selection for filtered export + with gr.Column(visible=False) as keyword_col: + keyword_input = gr.Textbox( + label="Enter Keywords (comma-separated)", + placeholder="Enter keywords to filter prompts..." 
+ ) + + # Export format selection + export_format = gr.Radio( + choices=["CSV", "Markdown (ZIP)"], + label="Export Format", + value="CSV" + ) + + # Export options + include_options = gr.CheckboxGroup( + choices=[ + "Include System Prompts", + "Include User Prompts", + "Include Details", + "Include Author", + "Include Keywords" + ], + label="Export Options", + value=["Include Keywords", "Include Author"] + ) + + # Markdown-specific options (only visible when Markdown is selected) + with gr.Column(visible=False) as markdown_options_col: + markdown_template = gr.Radio( + choices=[ + "Basic Template", + "Detailed Template", + "Custom Template" + ], + label="Markdown Template", + value="Basic Template" + ) + custom_template = gr.Textbox( + label="Custom Template", + placeholder="Use {title}, {author}, {details}, {system}, {user}, {keywords} as placeholders", + visible=False + ) + + export_button = gr.Button("Export Prompts") + + with gr.Column(): + export_status = gr.Textbox(label="Export Status", interactive=False) + export_file = gr.File(label="Download Export") + + def update_ui_visibility(export_type, format_choice, template_choice): + """Update UI elements visibility based on selections""" + show_keywords = export_type == "Prompts by Keyword" + show_markdown_options = format_choice == "Markdown (ZIP)" + show_custom_template = template_choice == "Custom Template" and show_markdown_options + + return [ + gr.update(visible=show_keywords), # keyword_col + gr.update(visible=show_markdown_options), # markdown_options_col + gr.update(visible=show_custom_template) # custom_template + ] + + def handle_export(export_type, keywords, export_format, options, markdown_template, custom_template): + """Handle the export process based on selected options""" + try: + # Parse options + include_system = "Include System Prompts" in options + include_user = "Include User Prompts" in options + include_details = "Include Details" in options + include_author = "Include Author" in options + 
include_keywords = "Include Keywords" in options + + # Handle keyword filtering + keyword_list = None + if export_type == "Prompts by Keyword" and keywords: + keyword_list = [k.strip() for k in keywords.split(",") if k.strip()] + + # Get the appropriate template + template = None + if export_format == "Markdown (ZIP)": + if markdown_template == "Custom Template": + template = custom_template + else: + template = markdown_template + + # Perform export + from App_Function_Libraries.DB.Prompts_DB import export_prompts + status, file_path = export_prompts( + export_format=export_format.split()[0].lower(), # 'csv' or 'markdown' + filter_keywords=keyword_list, + include_system=include_system, + include_user=include_user, + include_details=include_details, + include_author=include_author, + include_keywords=include_keywords, + markdown_template=template + ) + + return status, file_path + + except Exception as e: + error_msg = f"Export failed: {str(e)}" + logging.error(error_msg) + return error_msg, None + + # Event handlers + export_type.change( + fn=lambda t, f, m: update_ui_visibility(t, f, m), + inputs=[export_type, export_format, markdown_template], + outputs=[keyword_col, markdown_options_col, custom_template] + ) + + export_format.change( + fn=lambda t, f, m: update_ui_visibility(t, f, m), + inputs=[export_type, export_format, markdown_template], + outputs=[keyword_col, markdown_options_col, custom_template] + ) + + markdown_template.change( + fn=lambda t, f, m: update_ui_visibility(t, f, m), + inputs=[export_type, export_format, markdown_template], + outputs=[keyword_col, markdown_options_col, custom_template] + ) + + export_button.click( + fn=handle_export, + inputs=[ + export_type, + keyword_input, + export_format, + include_options, + markdown_template, + custom_template + ], + outputs=[export_status, export_file] + ) + +# +# End of Prompts_tab.py +#################################################################################################### diff --git 
a/App_Function_Libraries/Gradio_UI/RAG_Chat_tab.py b/App_Function_Libraries/Gradio_UI/RAG_Chat_tab.py index 5a470effefca66d867e44960cf90da849ca5d38e..b213c832a972df0e357023c31822c58d31d40d17 100644 --- a/App_Function_Libraries/Gradio_UI/RAG_Chat_tab.py +++ b/App_Function_Libraries/Gradio_UI/RAG_Chat_tab.py @@ -10,12 +10,26 @@ import gradio as gr # Local Imports from App_Function_Libraries.RAG.RAG_Library_2 import enhanced_rag_pipeline +from App_Function_Libraries.Utils.Utils import default_api_endpoint, global_api_endpoints, format_api_name + + # ######################################################################################################################## # # Functions: def create_rag_tab(): + try: + default_value = None + if default_api_endpoint: + if default_api_endpoint in global_api_endpoints: + default_value = format_api_name(default_api_endpoint) + else: + logging.warning(f"Default API endpoint '{default_api_endpoint}' not found in global_api_endpoints") + except Exception as e: + logging.error(f"Error setting default API endpoint: {str(e)}") + default_value = None + with gr.TabItem("RAG Search", visible=True): gr.Markdown("# Retrieval-Augmented Generation (RAG) Search") @@ -36,10 +50,11 @@ def create_rag_tab(): visible=False ) + # Refactored API selection dropdown api_choice = gr.Dropdown( - choices=["Local-LLM", "OpenAI", "Anthropic", "Cohere", "Groq", "DeepSeek", "Mistral", "OpenRouter", "Llama.cpp", "Kobold", "Ooba", "Tabbyapi", "VLLM", "ollama", "HuggingFace"], - label="Select API for RAG", - value="OpenAI" + choices=["None"] + [format_api_name(api) for api in global_api_endpoints], + value=default_value, + label="API for Chat Response (Optional)" ) search_button = gr.Button("Search") diff --git a/App_Function_Libraries/Gradio_UI/RAG_QA_Chat_tab.py b/App_Function_Libraries/Gradio_UI/RAG_QA_Chat_tab.py index b25b58bd04df4e10ecdd16bc6e0129d3854b32ba..9764e3fb543348457c0312880a42b58637309b73 100644 --- 
a/App_Function_Libraries/Gradio_UI/RAG_QA_Chat_tab.py +++ b/App_Function_Libraries/Gradio_UI/RAG_QA_Chat_tab.py @@ -6,6 +6,7 @@ import csv import logging import json import os +import re from datetime import datetime # # External Imports @@ -14,32 +15,39 @@ import gradio as gr # # Local Imports from App_Function_Libraries.Books.Book_Ingestion_Lib import read_epub -from App_Function_Libraries.DB.DB_Manager import DatabaseError, get_paginated_files, add_media_with_keywords -from App_Function_Libraries.DB.RAG_QA_Chat_DB import ( - save_notes, - add_keywords_to_note, - start_new_conversation, - save_message, - search_conversations_by_keywords, - load_chat_history, - get_all_conversations, - get_note_by_id, - get_notes_by_keywords, - get_notes_by_keyword_collection, - update_note, - clear_keywords_from_note, get_notes, get_keywords_for_note, delete_conversation, delete_note, execute_query, - add_keywords_to_conversation, fetch_all_notes, fetch_all_conversations, fetch_conversations_by_ids, - fetch_notes_by_ids, -) +from App_Function_Libraries.DB.Character_Chat_DB import search_character_chat, search_character_cards +from App_Function_Libraries.DB.DB_Manager import DatabaseError, get_paginated_files, add_media_with_keywords, \ + get_all_conversations, get_note_by_id, get_notes_by_keywords, start_new_conversation, update_note, save_notes, \ + clear_keywords_from_note, add_keywords_to_note, load_chat_history, save_message, add_keywords_to_conversation, \ + get_keywords_for_note, delete_note, search_conversations_by_keywords, get_conversation_title, delete_conversation, \ + update_conversation_title, fetch_all_conversations, fetch_all_notes, fetch_conversations_by_ids, fetch_notes_by_ids, \ + search_media_db, search_notes_titles, list_prompts +from App_Function_Libraries.DB.RAG_QA_Chat_DB import get_notes, delete_messages_in_conversation, search_rag_notes, \ + search_rag_chat, get_conversation_rating, set_conversation_rating +from 
App_Function_Libraries.Gradio_UI.Gradio_Shared import update_user_prompt from App_Function_Libraries.PDF.PDF_Ingestion_Lib import extract_text_and_format_from_pdf from App_Function_Libraries.RAG.RAG_Library_2 import generate_answer, enhanced_rag_pipeline from App_Function_Libraries.RAG.RAG_QA_Chat import search_database, rag_qa_chat +from App_Function_Libraries.Utils.Utils import default_api_endpoint, global_api_endpoints, format_api_name, \ + load_comprehensive_config + + # ######################################################################################################################## # # Functions: def create_rag_qa_chat_tab(): + try: + default_value = None + if default_api_endpoint: + if default_api_endpoint in global_api_endpoints: + default_value = format_api_name(default_api_endpoint) + else: + logging.warning(f"Default API endpoint '{default_api_endpoint}' not found in global_api_endpoints") + except Exception as e: + logging.error(f"Error setting default API endpoint: {str(e)}") + default_value = None with gr.TabItem("RAG QA Chat", visible=True): gr.Markdown("# RAG QA Chat") @@ -47,18 +55,53 @@ def create_rag_qa_chat_tab(): "page": 1, "context_source": "Entire Media Database", "conversation_messages": [], + "conversation_id": None }) note_state = gr.State({"note_id": None}) + def auto_save_conversation(message, response, state_value, auto_save_enabled): + """Automatically save the conversation if auto-save is enabled""" + try: + if not auto_save_enabled: + return state_value + + conversation_id = state_value.get("conversation_id") + if not conversation_id: + # Create new conversation with default title + title = "Auto-saved Conversation " + datetime.now().strftime("%Y-%m-%d %H:%M:%S") + conversation_id = start_new_conversation(title=title) + state_value = state_value.copy() + state_value["conversation_id"] = conversation_id + + # Save the messages + save_message(conversation_id, "user", message) + save_message(conversation_id, "assistant", response) 
+ + return state_value + except Exception as e: + logging.error(f"Error in auto-save: {str(e)}") + return state_value + # Update the conversation list function def update_conversation_list(): conversations, total_pages, total_count = get_all_conversations() - choices = [f"{title} (ID: {conversation_id})" for conversation_id, title in conversations] + choices = [ + f"{conversation['title']} (ID: {conversation['conversation_id']}) - Rating: {conversation['rating'] or 'Not Rated'}" + for conversation in conversations + ] return choices with gr.Row(): with gr.Column(scale=1): + # FIXME - Offer the user to search 2+ databases at once + database_types = ["Media DB", "RAG Chat", "RAG Notes", "Character Chat", "Character Cards"] + db_choice = gr.CheckboxGroup( + label="Select Database(s)", + choices=database_types, + value=["Media DB"], + interactive=True + ) context_source = gr.Radio( ["All Files in the Database", "Search Database", "Upload File"], label="Context Source", @@ -71,19 +114,52 @@ def create_rag_qa_chat_tab(): next_page_btn = gr.Button("Next Page") page_info = gr.HTML("Page 1") top_k_input = gr.Number(value=10, label="Maximum amount of results to use (Default: 10)", minimum=1, maximum=50, step=1, precision=0, interactive=True) - keywords_input = gr.Textbox(label="Keywords (comma-separated) to filter results by)", visible=True) + keywords_input = gr.Textbox(label="Keywords (comma-separated) to filter results by)", value="rag_qa_default_keyword" ,visible=True) use_query_rewriting = gr.Checkbox(label="Use Query Rewriting", value=True) use_re_ranking = gr.Checkbox(label="Use Re-ranking", value=True) - # with gr.Row(): - # page_number = gr.Number(value=1, label="Page", precision=0) - # page_size = gr.Number(value=20, label="Items per page", precision=0) - # total_pages = gr.Number(label="Total Pages", interactive=False) + config = load_comprehensive_config() + auto_save_value = config.getboolean('auto-save', 'save_character_chats', fallback=False) + 
auto_save_checkbox = gr.Checkbox( + label="Save chats automatically", + value=auto_save_value, + info="When enabled, conversations will be saved automatically after each message" + ) + + initial_prompts, total_pages, current_page = list_prompts(page=1, per_page=10) + preset_prompt_checkbox = gr.Checkbox( + label="View Custom Prompts(have to copy/paste them)", + value=False, + visible=True + ) + + with gr.Row(visible=False) as preset_prompt_controls: + prev_prompt_page = gr.Button("Previous") + current_prompt_page_text = gr.Text(f"Page {current_page} of {total_pages}") + next_prompt_page = gr.Button("Next") + current_prompt_page_state = gr.State(value=1) + + preset_prompt = gr.Dropdown( + label="Select Preset Prompt", + choices=initial_prompts, + visible=False + ) + user_prompt = gr.Textbox( + label="Custom Prompt", + placeholder="Enter custom prompt here", + lines=3, + visible=False + ) + + system_prompt_input = gr.Textbox( + label="System Prompt", + lines=3, + visible=False + ) search_query = gr.Textbox(label="Search Query", visible=False) search_button = gr.Button("Search", visible=False) search_results = gr.Dropdown(label="Search Results", choices=[], visible=False) - # FIXME - Add pages for search results handling file_upload = gr.File( label="Upload File", visible=False, @@ -95,34 +171,28 @@ def create_rag_qa_chat_tab(): load_conversation = gr.Dropdown( label="Load Conversation", choices=update_conversation_list() - ) + ) new_conversation = gr.Button("New Conversation") save_conversation_button = gr.Button("Save Conversation") conversation_title = gr.Textbox( - label="Conversation Title", placeholder="Enter a title for the new conversation" + label="Conversation Title", + placeholder="Enter a title for the new conversation" ) keywords = gr.Textbox(label="Keywords (comma-separated)", visible=True) + # Add the rating display and input + rating_display = gr.Markdown(value="", visible=False) + rating_input = gr.Radio( + choices=["1", "2", "3"], + label="Rate this 
Conversation (1-3 stars)", + visible=False + ) + + # Refactored API selection dropdown api_choice = gr.Dropdown( - choices=[ - "Local-LLM", - "OpenAI", - "Anthropic", - "Cohere", - "Groq", - "DeepSeek", - "Mistral", - "OpenRouter", - "Llama.cpp", - "Kobold", - "Ooba", - "Tabbyapi", - "VLLM", - "ollama", - "HuggingFace", - ], - label="Select API for RAG", - value="OpenAI", + choices=["None"] + [format_api_name(api) for api in global_api_endpoints], + value=default_value, + label="API for Chat Response (Optional)" ) with gr.Row(): @@ -145,6 +215,8 @@ def create_rag_qa_chat_tab(): clear_notes_btn = gr.Button("Clear Current Note text") new_note_btn = gr.Button("New Note") + # FIXME - Change from only keywords to generalized search + search_notes_title = gr.Textbox(label="Search Notes by Title") search_notes_by_keyword = gr.Textbox(label="Search Notes by Keyword") search_notes_button = gr.Button("Search Notes") note_results = gr.Dropdown(label="Notes", choices=[]) @@ -152,8 +224,58 @@ def create_rag_qa_chat_tab(): loading_indicator = gr.HTML("Loading...", visible=False) status_message = gr.HTML() + auto_save_status = gr.HTML() + + # Function Definitions + def update_prompt_page(direction, current_page_val): + new_page = max(1, min(total_pages, current_page_val + direction)) + prompts, _, _ = list_prompts(page=new_page, per_page=10) + return ( + gr.update(choices=prompts), + gr.update(value=f"Page {new_page} of {total_pages}"), + new_page + ) + + def update_prompts(preset_name): + prompts = update_user_prompt(preset_name) + return ( + gr.update(value=prompts["user_prompt"], visible=True), + gr.update(value=prompts["system_prompt"], visible=True) + ) + + def toggle_preset_prompt(checkbox_value): + return ( + gr.update(visible=checkbox_value), + gr.update(visible=checkbox_value), + gr.update(visible=False), + gr.update(visible=False) + ) + + prev_prompt_page.click( + lambda x: update_prompt_page(-1, x), + inputs=[current_prompt_page_state], + outputs=[preset_prompt, 
current_prompt_page_text, current_prompt_page_state] + ) + + next_prompt_page.click( + lambda x: update_prompt_page(1, x), + inputs=[current_prompt_page_state], + outputs=[preset_prompt, current_prompt_page_text, current_prompt_page_state] + ) + + preset_prompt.change( + update_prompts, + inputs=preset_prompt, + outputs=[user_prompt, system_prompt_input] + ) + + preset_prompt_checkbox.change( + toggle_preset_prompt, + inputs=[preset_prompt_checkbox], + outputs=[preset_prompt, preset_prompt_controls, user_prompt, system_prompt_input] + ) def update_state(state, **kwargs): new_state = state.copy() @@ -168,18 +290,28 @@ def create_rag_qa_chat_tab(): outputs=[note_title, notes, note_state] ) - def search_notes(keywords): + def search_notes(search_notes_title, keywords): if keywords: keywords_list = [kw.strip() for kw in keywords.split(',')] notes_data, total_pages, total_count = get_notes_by_keywords(keywords_list) - choices = [f"Note {note_id} ({timestamp})" for note_id, title, content, timestamp in notes_data] - return gr.update(choices=choices) + choices = [f"Note {note_id} - {title} ({timestamp})" for + note_id, title, content, timestamp, conversation_id in notes_data] + return gr.update(choices=choices, label=f"Found {total_count} notes") + elif search_notes_title: + notes_data, total_pages, total_count = search_notes_titles(search_notes_title) + choices = [f"Note {note_id} - {title} ({timestamp})" for + note_id, title, content, timestamp, conversation_id in notes_data] + return gr.update(choices=choices, label=f"Found {total_count} notes") else: - return gr.update(choices=[]) + # This will now return all notes, ordered by timestamp + notes_data, total_pages, total_count = search_notes_titles("") + choices = [f"Note {note_id} - {title} ({timestamp})" for + note_id, title, content, timestamp, conversation_id in notes_data] + return gr.update(choices=choices, label=f"All notes ({total_count} total)") search_notes_button.click( search_notes, - 
inputs=[search_notes_by_keyword], + inputs=[search_notes_title, search_notes_by_keyword], outputs=[note_results] ) @@ -201,31 +333,69 @@ def create_rag_qa_chat_tab(): def save_notes_function(note_title_text, notes_content, keywords_content, note_state_value, state_value): """Save the notes and associated keywords to the database.""" - conversation_id = state_value.get("conversation_id") - note_id = note_state_value["note_id"] - if conversation_id and notes_content: + logging.info(f"Starting save_notes_function with state: {state_value}") + logging.info(f"Note title: {note_title_text}") + logging.info(f"Notes content length: {len(notes_content) if notes_content else 0}") + + try: + # Check current state + conversation_id = state_value.get("conversation_id") + logging.info(f"Current conversation_id: {conversation_id}") + + # Create new conversation if none exists + if not conversation_id: + logging.info("No conversation ID found, creating new conversation") + conversation_title = note_title_text if note_title_text else "Untitled Conversation" + conversation_id = start_new_conversation(title=conversation_title) + state_value = state_value.copy() # Create a new copy of the state + state_value["conversation_id"] = conversation_id + logging.info(f"Created new conversation with ID: {conversation_id}") + + if not notes_content: + logging.warning("No notes content provided") + return notes_content, note_state_value, state_value, gr.update( + value="

Cannot save empty notes.

") + + # Save or update note + note_id = note_state_value.get("note_id") if note_id: - # Update existing note + logging.info(f"Updating existing note with ID: {note_id}") update_note(note_id, note_title_text, notes_content) else: - # Save new note - note_id = save_notes(conversation_id, note_title_text, notes_content) - note_state_value["note_id"] = note_id + logging.info(f"Creating new note for conversation: {conversation_id}") + note_id = save_notes(conversation_id, note_title_text or "Untitled Note", notes_content) + note_state_value = {"note_id": note_id} + logging.info(f"Created new note with ID: {note_id}") + + # Handle keywords if keywords_content: - # Clear existing keywords and add new ones + logging.info("Processing keywords") clear_keywords_from_note(note_id) - add_keywords_to_note(note_id, [kw.strip() for kw in keywords_content.split(',')]) + keywords = [kw.strip() for kw in keywords_content.split(',')] + add_keywords_to_note(note_id, keywords) + logging.info(f"Added keywords: {keywords}") - logging.info("Notes and keywords saved successfully!") - return notes_content, note_state_value - else: - logging.warning("No conversation ID or notes to save.") - return "", note_state_value + logging.info("Notes saved successfully") + return ( + notes_content, + note_state_value, + state_value, + gr.update(value="

Notes saved successfully!

") + ) + + except Exception as e: + logging.error(f"Error in save_notes_function: {str(e)}", exc_info=True) + return ( + notes_content, + note_state_value, + state_value, + gr.update(value=f"

Error saving notes: {str(e)}

") + ) save_notes_btn.click( save_notes_function, inputs=[note_title, notes, keywords_for_notes, note_state, state], - outputs=[notes, note_state] + outputs=[notes, note_state, state, status_message] ) def clear_notes_function(): @@ -237,83 +407,112 @@ def create_rag_qa_chat_tab(): outputs=[notes, note_state] ) - def update_conversation_list(): - conversations, total_pages, total_count = get_all_conversations() - choices = [f"{title} (ID: {conversation_id})" for conversation_id, title in conversations] - return choices - # Initialize the conversation list load_conversation.choices = update_conversation_list() def load_conversation_history(selected_conversation, state_value): - if selected_conversation: - conversation_id = selected_conversation.split('(ID: ')[1][:-1] + try: + if not selected_conversation: + return [], state_value, "", gr.update(value="", visible=False), gr.update(visible=False) + # Extract conversation ID + match = re.search(r'\(ID: ([0-9a-fA-F\-]+)\)', selected_conversation) + if not match: + logging.error(f"Invalid conversation format: {selected_conversation}") + return [], state_value, "", gr.update(value="", visible=False), gr.update(visible=False) + conversation_id = match.group(1) chat_data, total_pages_val, _ = load_chat_history(conversation_id, 1, 50) - # Convert chat data to list of tuples (user_message, assistant_response) + # Update state with valid conversation id + updated_state = state_value.copy() + updated_state["conversation_id"] = conversation_id + updated_state["conversation_messages"] = chat_data + # Format chat history history = [] for role, content in chat_data: if role == 'user': history.append((content, '')) - else: - if history: - history[-1] = (history[-1][0], content) - else: - history.append(('', content)) - # Retrieve notes + elif history: + history[-1] = (history[-1][0], content) + # Fetch and display the conversation rating + rating = get_conversation_rating(conversation_id) + if rating is not None: + rating_text = 
f"**Current Rating:** {rating} star(s)" + rating_display_update = gr.update(value=rating_text, visible=True) + rating_input_update = gr.update(value=str(rating), visible=True) + else: + rating_display_update = gr.update(value="**Current Rating:** Not Rated", visible=True) + rating_input_update = gr.update(value=None, visible=True) notes_content = get_notes(conversation_id) - updated_state = update_state(state_value, conversation_id=conversation_id, page=1, - conversation_messages=[]) - return history, updated_state, "\n".join(notes_content) - return [], state_value, "" + return history, updated_state, "\n".join( + notes_content) if notes_content else "", rating_display_update, rating_input_update + except Exception as e: + logging.error(f"Error loading conversation: {str(e)}") + return [], state_value, "", gr.update(value="", visible=False), gr.update(visible=False) load_conversation.change( load_conversation_history, inputs=[load_conversation, state], - outputs=[chatbot, state, notes] + outputs=[chatbot, state, notes, rating_display, rating_input] ) # Modify save_conversation_function to use gr.update() - def save_conversation_function(conversation_title_text, keywords_text, state_value): + def save_conversation_function(conversation_title_text, keywords_text, rating_value, state_value): conversation_messages = state_value.get("conversation_messages", []) + conversation_id = state_value.get("conversation_id") if not conversation_messages: return gr.update( value="

No conversation to save.

" - ), state_value, gr.update() - # Start a new conversation in the database - new_conversation_id = start_new_conversation( - conversation_title_text if conversation_title_text else "Untitled Conversation" - ) + ), state_value, gr.update(), gr.update(value="", visible=False), gr.update(visible=False) + # Start a new conversation in the database if not existing + if not conversation_id: + conversation_id = start_new_conversation( + conversation_title_text if conversation_title_text else "Untitled Conversation" + ) + else: + # Update the conversation title if it has changed + update_conversation_title(conversation_id, conversation_title_text) # Save the messages for role, content in conversation_messages: - save_message(new_conversation_id, role, content) + save_message(conversation_id, role, content) # Save keywords if provided if keywords_text: - add_keywords_to_conversation(new_conversation_id, [kw.strip() for kw in keywords_text.split(',')]) + add_keywords_to_conversation(conversation_id, [kw.strip() for kw in keywords_text.split(',')]) + # Save the rating if provided + try: + if rating_value: + set_conversation_rating(conversation_id, int(rating_value)) + except ValueError as ve: + logging.error(f"Invalid rating value: {ve}") + return gr.update( + value=f"

Invalid rating: {ve}

" + ), state_value, gr.update(), gr.update(value="", visible=False), gr.update(visible=False) + # Update state - updated_state = update_state(state_value, conversation_id=new_conversation_id) + updated_state = update_state(state_value, conversation_id=conversation_id) # Update the conversation list conversation_choices = update_conversation_list() + # Reset rating display and input + rating_display_update = gr.update(value=f"**Current Rating:** {rating_value} star(s)", visible=True) + rating_input_update = gr.update(value=rating_value, visible=True) return gr.update( value="

Conversation saved successfully.

" - ), updated_state, gr.update(choices=conversation_choices) + ), updated_state, gr.update(choices=conversation_choices), rating_display_update, rating_input_update save_conversation_button.click( save_conversation_function, - inputs=[conversation_title, keywords, state], - outputs=[status_message, state, load_conversation] + inputs=[conversation_title, keywords, rating_input, state], + outputs=[status_message, state, load_conversation, rating_display, rating_input] ) def start_new_conversation_wrapper(title, state_value): - # Reset the state with no conversation_id - updated_state = update_state(state_value, conversation_id=None, page=1, - conversation_messages=[]) - # Clear the chat history - return [], updated_state + # Reset the state with no conversation_id and empty conversation messages + updated_state = update_state(state_value, conversation_id=None, page=1, conversation_messages=[]) + # Clear the chat history and reset rating components + return [], updated_state, gr.update(value="", visible=False), gr.update(value=None, visible=False) new_conversation.click( start_new_conversation_wrapper, inputs=[conversation_title, state], - outputs=[chatbot, state] + outputs=[chatbot, state, rating_display, rating_input] ) def update_file_list(page): @@ -328,11 +527,12 @@ def create_rag_qa_chat_tab(): return update_file_list(max(1, current_page - 1)) def update_context_source(choice): + # Update visibility based on context source choice return { existing_file: gr.update(visible=choice == "Existing File"), - prev_page_btn: gr.update(visible=choice == "Existing File"), - next_page_btn: gr.update(visible=choice == "Existing File"), - page_info: gr.update(visible=choice == "Existing File"), + prev_page_btn: gr.update(visible=choice == "Search Database"), + next_page_btn: gr.update(visible=choice == "Search Database"), + page_info: gr.update(visible=choice == "Search Database"), search_query: gr.update(visible=choice == "Search Database"), search_button: 
gr.update(visible=choice == "Search Database"), search_results: gr.update(visible=choice == "Search Database"), @@ -352,17 +552,36 @@ def create_rag_qa_chat_tab(): context_source.change(lambda choice: update_file_list(1) if choice == "Existing File" else (gr.update(), gr.update(), 1), inputs=[context_source], outputs=[existing_file, page_info, file_page]) - def perform_search(query): + def perform_search(query, selected_databases, keywords): try: - results = search_database(query) + results = [] + + # Iterate over selected database types and perform searches accordingly + for database_type in selected_databases: + if database_type == "Media DB": + # FIXME - check for existence of keywords before setting as search field + search_fields = ["title", "content", "keywords"] + results += search_media_db(query, search_fields, keywords, page=1, results_per_page=25) + elif database_type == "RAG Chat": + results += search_rag_chat(query) + elif database_type == "RAG Notes": + results += search_rag_notes(query) + elif database_type == "Character Chat": + results += search_character_chat(query) + elif database_type == "Character Cards": + results += search_character_cards(query) + + # Remove duplicate results if necessary + results = list(set(results)) return gr.update(choices=results) except Exception as e: gr.Error(f"Error performing search: {str(e)}") return gr.update(choices=[]) + # Click Event for the DB Search Button search_button.click( perform_search, - inputs=[search_query], + inputs=[search_query, db_choice, keywords_input], outputs=[search_results] ) @@ -384,17 +603,22 @@ Rewritten Question:""" logging.info(f"Rephrased question: {rephrased_question}") return rephrased_question.strip() - def rag_qa_chat_wrapper(message, history, context_source, existing_file, search_results, file_upload, - convert_to_text, keywords, api_choice, use_query_rewriting, state_value, - keywords_input, top_k_input, use_re_ranking): + # FIXME - RAG DB selection + def rag_qa_chat_wrapper( + 
message, history, context_source, existing_file, search_results, file_upload, + convert_to_text, keywords, api_choice, use_query_rewriting, state_value, + keywords_input, top_k_input, use_re_ranking, db_choices, auto_save_enabled + ): try: logging.info(f"Starting rag_qa_chat_wrapper with message: {message}") logging.info(f"Context source: {context_source}") logging.info(f"API choice: {api_choice}") logging.info(f"Query rewriting: {'enabled' if use_query_rewriting else 'disabled'}") + logging.info(f"Selected DB Choices: {db_choices}") # Show loading indicator - yield history, "", gr.update(visible=True), state_value + yield history, "", gr.update(visible=True), state_value, gr.update(visible=False), gr.update( + visible=False) conversation_id = state_value.get("conversation_id") conversation_messages = state_value.get("conversation_messages", []) @@ -408,12 +632,12 @@ Rewritten Question:""" state_value["conversation_messages"] = conversation_messages # Ensure api_choice is a string - api_choice = api_choice.value if isinstance(api_choice, gr.components.Dropdown) else api_choice - logging.info(f"Resolved API choice: {api_choice}") + api_choice_str = api_choice.value if isinstance(api_choice, gr.components.Dropdown) else api_choice + logging.info(f"Resolved API choice: {api_choice_str}") # Only rephrase the question if it's not the first query and query rewriting is enabled if len(history) > 0 and use_query_rewriting: - rephrased_question = rephrase_question(history, message, api_choice) + rephrased_question = rephrase_question(history, message, api_choice_str) logging.info(f"Original question: {message}") logging.info(f"Rephrased question: {rephrased_question}") else: @@ -421,18 +645,20 @@ Rewritten Question:""" logging.info(f"Using original question: {message}") if context_source == "All Files in the Database": - # Use the enhanced_rag_pipeline to search the entire database - context = enhanced_rag_pipeline(rephrased_question, api_choice, keywords_input, 
top_k_input, - use_re_ranking) + # Use the enhanced_rag_pipeline to search the selected databases + context = enhanced_rag_pipeline( + rephrased_question, api_choice_str, keywords_input, top_k_input, use_re_ranking, + database_types=db_choices # Pass the list of selected databases + ) logging.info(f"Using enhanced_rag_pipeline for database search") elif context_source == "Search Database": context = f"media_id:{search_results.split('(ID: ')[1][:-1]}" logging.info(f"Using search result with context: {context}") - else: # Upload File + else: + # Upload File logging.info("Processing uploaded file") if file_upload is None: raise ValueError("No file uploaded") - # Process the uploaded file file_path = file_upload.name file_name = os.path.basename(file_path) @@ -445,7 +671,6 @@ Rewritten Question:""" logging.info("Reading file content") with open(file_path, 'r', encoding='utf-8') as f: content = f.read() - logging.info(f"File content length: {len(content)} characters") # Process keywords @@ -467,18 +692,17 @@ Rewritten Question:""" author='Unknown', ingestion_date=datetime.now().strftime('%Y-%m-%d') ) - logging.info(f"Result from add_media_with_keywords: {result}") if isinstance(result, tuple): media_id, _ = result else: media_id = result - context = f"media_id:{media_id}" logging.info(f"Context for uploaded file: {context}") logging.info("Calling rag_qa_chat function") - new_history, response = rag_qa_chat(rephrased_question, history, context, api_choice) + new_history, response = rag_qa_chat(rephrased_question, history, context, api_choice_str) + # Log first 100 chars of response logging.info(f"Response received from rag_qa_chat: {response[:100]}...") @@ -490,7 +714,8 @@ Rewritten Question:""" state_value["conversation_messages"] = conversation_messages # Update the state - state_value["conversation_messages"] = conversation_messages + updated_state = auto_save_conversation(message, response, state_value, auto_save_enabled) + updated_state["conversation_messages"] = 
conversation_messages # Safely update history if new_history: @@ -498,24 +723,43 @@ Rewritten Question:""" else: new_history = [(message, response)] + # Get the current rating and update display + conversation_id = updated_state.get("conversation_id") + if conversation_id: + rating = get_conversation_rating(conversation_id) + if rating is not None: + rating_display_update = gr.update(value=f"**Current Rating:** {rating} star(s)", visible=True) + rating_input_update = gr.update(value=str(rating), visible=True) + else: + rating_display_update = gr.update(value="**Current Rating:** Not Rated", visible=True) + rating_input_update = gr.update(value=None, visible=True) + else: + rating_display_update = gr.update(value="", visible=False) + rating_input_update = gr.update(value=None, visible=False) + gr.Info("Response generated successfully") logging.info("rag_qa_chat_wrapper completed successfully") - yield new_history, "", gr.update(visible=False), state_value # Include state_value in outputs + yield new_history, "", gr.update( + visible=False), updated_state, rating_display_update, rating_input_update + except ValueError as e: logging.error(f"Input error in rag_qa_chat_wrapper: {str(e)}") gr.Error(f"Input error: {str(e)}") - yield history, "", gr.update(visible=False), state_value + yield history, "", gr.update(visible=False), state_value, gr.update(visible=False), gr.update( + visible=False) except DatabaseError as e: logging.error(f"Database error in rag_qa_chat_wrapper: {str(e)}") gr.Error(f"Database error: {str(e)}") - yield history, "", gr.update(visible=False), state_value + yield history, "", gr.update(visible=False), state_value, gr.update(visible=False), gr.update( + visible=False) except Exception as e: logging.error(f"Unexpected error in rag_qa_chat_wrapper: {e}", exc_info=True) gr.Error("An unexpected error occurred. 
Please try again later.") - yield history, "", gr.update(visible=False), state_value + yield history, "", gr.update(visible=False), state_value, gr.update(visible=False), gr.update( + visible=False) def clear_chat_history(): - return [], "" + return [], "", gr.update(value="", visible=False), gr.update(value=None, visible=False) submit.click( rag_qa_chat_wrapper, @@ -532,14 +776,17 @@ Rewritten Question:""" use_query_rewriting, state, keywords_input, - top_k_input + top_k_input, + use_re_ranking, + db_choice, + auto_save_checkbox ], - outputs=[chatbot, msg, loading_indicator, state], + outputs=[chatbot, msg, loading_indicator, state, rating_display, rating_input], ) clear_chat.click( clear_chat_history, - outputs=[chatbot, msg] + outputs=[chatbot, msg, rating_display, rating_input] ) return ( @@ -560,12 +807,10 @@ Rewritten Question:""" ) - def create_rag_qa_notes_management_tab(): # New Management Tab with gr.TabItem("Notes Management", visible=True): gr.Markdown("# RAG QA Notes Management") - management_state = gr.State({ "selected_conversation_id": None, "selected_note_id": None, @@ -574,7 +819,8 @@ def create_rag_qa_notes_management_tab(): with gr.Row(): with gr.Column(scale=1): # Search Notes - search_notes_input = gr.Textbox(label="Search Notes by Keywords") + search_notes_title = gr.Textbox(label="Search Notes by Title") + search_notes_by_keyword = gr.Textbox(label="Search Notes by Keywords") search_notes_button = gr.Button("Search Notes") notes_list = gr.Dropdown(label="Notes", choices=[]) @@ -583,24 +829,34 @@ def create_rag_qa_notes_management_tab(): delete_note_button = gr.Button("Delete Note") note_title_input = gr.Textbox(label="Note Title") note_content_input = gr.TextArea(label="Note Content", lines=20) - note_keywords_input = gr.Textbox(label="Note Keywords (comma-separated)") + note_keywords_input = gr.Textbox(label="Note Keywords (comma-separated)", value="default_note_keyword") save_note_button = gr.Button("Save Note") create_new_note_button = 
gr.Button("Create New Note") status_message = gr.HTML() # Function Definitions - def search_notes(keywords): + def search_notes(search_notes_title, keywords): if keywords: keywords_list = [kw.strip() for kw in keywords.split(',')] notes_data, total_pages, total_count = get_notes_by_keywords(keywords_list) - choices = [f"{title} (ID: {note_id})" for note_id, title, content, timestamp in notes_data] - return gr.update(choices=choices) + choices = [f"Note {note_id} - {title} ({timestamp})" for + note_id, title, content, timestamp, conversation_id in notes_data] + return gr.update(choices=choices, label=f"Found {total_count} notes") + elif search_notes_title: + notes_data, total_pages, total_count = search_notes_titles(search_notes_title) + choices = [f"Note {note_id} - {title} ({timestamp})" for + note_id, title, content, timestamp, conversation_id in notes_data] + return gr.update(choices=choices, label=f"Found {total_count} notes") else: - return gr.update(choices=[]) + # This will now return all notes, ordered by timestamp + notes_data, total_pages, total_count = search_notes_titles("") + choices = [f"Note {note_id} - {title} ({timestamp})" for + note_id, title, content, timestamp, conversation_id in notes_data] + return gr.update(choices=choices, label=f"All notes ({total_count} total)") search_notes_button.click( search_notes, - inputs=[search_notes_input], + inputs=[search_notes_title, search_notes_by_keyword], outputs=[notes_list] ) @@ -664,7 +920,7 @@ def create_rag_qa_notes_management_tab(): # Reset state state_value["selected_note_id"] = None # Update notes list - updated_notes = search_notes("") + updated_notes = search_notes("", "") return updated_notes, gr.update(value="Note deleted successfully."), state_value else: return gr.update(), gr.update(value="No note selected."), state_value @@ -702,7 +958,20 @@ def create_rag_qa_chat_management_tab(): with gr.Row(): with gr.Column(scale=1): # Search Conversations - search_conversations_input = 
gr.Textbox(label="Search Conversations by Keywords") + with gr.Group(): + gr.Markdown("## Search Conversations") + title_search = gr.Textbox( + label="Search by Title", + placeholder="Enter title to search..." + ) + content_search = gr.Textbox( + label="Search in Chat Content", + placeholder="Enter text to search in messages..." + ) + keyword_search = gr.Textbox( + label="Filter by Keywords (comma-separated)", + placeholder="keyword1, keyword2, ..." + ) search_conversations_button = gr.Button("Search Conversations") conversations_list = gr.Dropdown(label="Conversations", choices=[]) new_conversation_button = gr.Button("New Conversation") @@ -716,26 +985,40 @@ def create_rag_qa_chat_management_tab(): status_message = gr.HTML() # Function Definitions - def search_conversations(keywords): - if keywords: - keywords_list = [kw.strip() for kw in keywords.split(',')] - conversations, total_pages, total_count = search_conversations_by_keywords(keywords_list) - else: - conversations, total_pages, total_count = get_all_conversations() + def search_conversations(title_query, content_query, keywords): + try: + # Parse keywords if provided + keywords_list = None + if keywords and keywords.strip(): + keywords_list = [kw.strip() for kw in keywords.split(',')] + + # Search using existing search_conversations_by_keywords function with all criteria + results, total_pages, total_count = search_conversations_by_keywords( + keywords=keywords_list, + title_query=title_query if title_query.strip() else None, + content_query=content_query if content_query.strip() else None + ) - # Build choices as list of titles (ensure uniqueness) - choices = [] - mapping = {} - for conversation_id, title in conversations: - display_title = f"{title} (ID: {conversation_id[:8]})" - choices.append(display_title) - mapping[display_title] = conversation_id + # Build choices as list of titles (ensure uniqueness) + choices = [] + mapping = {} + for conv in results: + conversation_id = conv['conversation_id'] + 
title = conv['title'] + display_title = f"{title} (ID: {conversation_id[:8]})" + choices.append(display_title) + mapping[display_title] = conversation_id - return gr.update(choices=choices), mapping + return gr.update(choices=choices), mapping + except Exception as e: + logging.error(f"Error in search_conversations: {str(e)}") + return gr.update(choices=[]), {} + + # Update the search button click event search_conversations_button.click( search_conversations, - inputs=[search_conversations_input], + inputs=[title_search, content_search, keyword_search], outputs=[conversations_list, conversation_mapping] ) @@ -892,19 +1175,18 @@ def create_rag_qa_chat_management_tab(): ] ) - def delete_messages_in_conversation(conversation_id): - """Helper function to delete all messages in a conversation.""" + def delete_messages_in_conversation_wrapper(conversation_id): + """Wrapper function to delete all messages in a conversation.""" try: - execute_query("DELETE FROM rag_qa_chats WHERE conversation_id = ?", (conversation_id,)) + delete_messages_in_conversation(conversation_id) logging.info(f"Messages in conversation '{conversation_id}' deleted successfully.") except Exception as e: logging.error(f"Error deleting messages in conversation '{conversation_id}': {e}") raise - def get_conversation_title(conversation_id): + def get_conversation_title_wrapper(conversation_id): """Helper function to get the conversation title.""" - query = "SELECT title FROM conversation_metadata WHERE conversation_id = ?" - result = execute_query(query, (conversation_id,)) + result = get_conversation_title(conversation_id) if result: return result[0][0] else: @@ -1034,19 +1316,6 @@ def create_export_data_tab(): ) - - -def update_conversation_title(conversation_id, new_title): - """Update the title of a conversation.""" - try: - query = "UPDATE conversation_metadata SET title = ? WHERE conversation_id = ?" 
- execute_query(query, (new_title, conversation_id)) - logging.info(f"Conversation '{conversation_id}' title updated to '{new_title}'") - except Exception as e: - logging.error(f"Error updating conversation title: {e}") - raise - - def convert_file_to_text(file_path): """Convert various file types to plain text.""" file_extension = os.path.splitext(file_path)[1].lower() diff --git a/App_Function_Libraries/Gradio_UI/Re_summarize_tab.py b/App_Function_Libraries/Gradio_UI/Re_summarize_tab.py index 290736120151e3d25d0e0ff187f28e36e6d54f81..3016249f28c75fd6dbc1c3ea996ca5344eba7e0b 100644 --- a/App_Function_Libraries/Gradio_UI/Re_summarize_tab.py +++ b/App_Function_Libraries/Gradio_UI/Re_summarize_tab.py @@ -10,19 +10,33 @@ import gradio as gr # # Local Imports from App_Function_Libraries.Chunk_Lib import improved_chunking_process -from App_Function_Libraries.DB.DB_Manager import update_media_content, load_preset_prompts +from App_Function_Libraries.DB.DB_Manager import update_media_content, list_prompts from App_Function_Libraries.Gradio_UI.Chat_ui import update_user_prompt from App_Function_Libraries.Gradio_UI.Gradio_Shared import fetch_item_details, fetch_items_by_keyword, \ fetch_items_by_content, fetch_items_by_title_or_url from App_Function_Libraries.Summarization.Summarization_General_Lib import summarize_chunk -from App_Function_Libraries.Utils.Utils import load_comprehensive_config -# +from App_Function_Libraries.Utils.Utils import load_comprehensive_config, default_api_endpoint, global_api_endpoints, \ + format_api_name # ###################################################################################################################### # # Functions: def create_resummary_tab(): + try: + default_value = None + if default_api_endpoint: + if default_api_endpoint in global_api_endpoints: + default_value = format_api_name(default_api_endpoint) + else: + logging.warning(f"Default API endpoint '{default_api_endpoint}' not found in global_api_endpoints") + except 
Exception as e: + logging.error(f"Error setting default API endpoint: {str(e)}") + default_value = None + + # Get initial prompts for first page + initial_prompts, total_pages, current_page = list_prompts(page=1, per_page=20) + with gr.TabItem("Re-Summarize", visible=True): gr.Markdown("# Re-Summarize Existing Content") with gr.Row(): @@ -36,9 +50,10 @@ def create_resummary_tab(): with gr.Row(): api_name_input = gr.Dropdown( - choices=["Local-LLM", "OpenAI", "Anthropic", "Cohere", "Groq", "DeepSeek", "Mistral", "OpenRouter", - "Llama.cpp", "Kobold", "Ooba", "Tabbyapi", "VLLM","ollama", "HuggingFace"], - value="Local-LLM", label="API Name") + choices=["None"] + [format_api_name(api) for api in global_api_endpoints], + value=default_value, + label="API for Summarization/Analysis (Optional)" + ) api_key_input = gr.Textbox(label="API Key", placeholder="Enter your API key here", type="password") chunking_options_checkbox = gr.Checkbox(label="Use Chunking", value=False) @@ -55,9 +70,17 @@ def create_resummary_tab(): preset_prompt_checkbox = gr.Checkbox(label="Use a pre-set Prompt", value=False, visible=True) + + # Add pagination controls for preset prompts + with gr.Row(visible=False) as preset_prompt_controls: + prev_page = gr.Button("Previous") + current_page_text = gr.Text(f"Page {current_page} of {total_pages}") + next_page = gr.Button("Next") + current_page_state = gr.State(value=1) + with gr.Row(): preset_prompt = gr.Dropdown(label="Select Preset Prompt", - choices=load_preset_prompts(), + choices=initial_prompts, visible=False) with gr.Row(): custom_prompt_input = gr.Textbox(label="Custom Prompt", @@ -86,6 +109,15 @@ def create_resummary_tab(): lines=3, visible=False) + def update_prompt_page(direction, current_page_val): + new_page = max(1, min(total_pages, current_page_val + direction)) + prompts, _, _ = list_prompts(page=new_page, per_page=10) + return ( + gr.update(choices=prompts), + gr.update(value=f"Page {new_page} of {total_pages}"), + new_page + ) + def 
update_prompts(preset_name): prompts = update_user_prompt(preset_name) return ( @@ -93,6 +125,19 @@ def create_resummary_tab(): gr.update(value=prompts["system_prompt"], visible=True) ) + # Connect pagination buttons + prev_page.click( + lambda x: update_prompt_page(-1, x), + inputs=[current_page_state], + outputs=[preset_prompt, current_page_text, current_page_state] + ) + + next_page.click( + lambda x: update_prompt_page(1, x), + inputs=[current_page_state], + outputs=[preset_prompt, current_page_text, current_page_state] + ) + preset_prompt.change( update_prompts, inputs=preset_prompt, @@ -109,9 +154,9 @@ def create_resummary_tab(): outputs=[custom_prompt_input, system_prompt_input] ) preset_prompt_checkbox.change( - fn=lambda x: gr.update(visible=x), + fn=lambda x: (gr.update(visible=x), gr.update(visible=x)), inputs=[preset_prompt_checkbox], - outputs=[preset_prompt] + outputs=[preset_prompt, preset_prompt_controls] ) # Connect the UI elements @@ -140,7 +185,12 @@ def create_resummary_tab(): outputs=result_output ) - return search_query_input, search_type_input, search_button, items_output, item_mapping, api_name_input, api_key_input, chunking_options_checkbox, chunking_options_box, chunk_method, max_chunk_size, chunk_overlap, custom_prompt_checkbox, custom_prompt_input, resummarize_button, result_output + return ( + search_query_input, search_type_input, search_button, items_output, + item_mapping, api_name_input, api_key_input, chunking_options_checkbox, + chunking_options_box, chunk_method, max_chunk_size, chunk_overlap, + custom_prompt_checkbox, custom_prompt_input, resummarize_button, result_output + ) def update_resummarize_dropdown(search_query, search_type): diff --git a/App_Function_Libraries/Gradio_UI/Search_Tab.py b/App_Function_Libraries/Gradio_UI/Search_Tab.py index 3b50ac1c1ffb98d495079377fb1e0c1b215e41ec..151879043d9048cf5f832978370ec143dc2642b8 100644 --- a/App_Function_Libraries/Gradio_UI/Search_Tab.py +++ 
b/App_Function_Libraries/Gradio_UI/Search_Tab.py @@ -11,8 +11,8 @@ import gradio as gr # # Local Imports from App_Function_Libraries.DB.DB_Manager import view_database, search_and_display_items, get_all_document_versions, \ - fetch_item_details_single, fetch_paginated_data, fetch_item_details, get_latest_transcription -from App_Function_Libraries.DB.SQLite_DB import search_prompts, get_document_version + fetch_item_details_single, fetch_paginated_data, fetch_item_details, get_latest_transcription, search_prompts, \ + get_document_version from App_Function_Libraries.Gradio_UI.Gradio_Shared import update_dropdown, update_detailed_view from App_Function_Libraries.Utils.Utils import get_database_path, format_text_with_line_breaks # @@ -80,8 +80,8 @@ def format_as_html(content, title): """ def create_search_tab(): - with gr.TabItem("Search / Detailed View", visible=True): - gr.Markdown("# Search across all ingested items in the Database") + with gr.TabItem("Media DB Search / Detailed View", visible=True): + gr.Markdown("# Search across all ingested items in the Media Database") with gr.Row(): with gr.Column(scale=1): gr.Markdown("by Title / URL / Keyword / or Content via SQLite Full-Text-Search") @@ -150,8 +150,8 @@ def display_search_results(query): def create_search_summaries_tab(): - with gr.TabItem("Search/View Title+Summary", visible=True): - gr.Markdown("# Search across all ingested items in the Database and review their summaries") + with gr.TabItem("Media DB Search/View Title+Summary", visible=True): + gr.Markdown("# Search across all ingested items in the Media Database and review their summaries") gr.Markdown("Search by Title / URL / Keyword / or Content via SQLite Full-Text-Search") with gr.Row(): with gr.Column(): diff --git a/App_Function_Libraries/Gradio_UI/Semantic_Scholar_tab.py b/App_Function_Libraries/Gradio_UI/Semantic_Scholar_tab.py new file mode 100644 index 0000000000000000000000000000000000000000..dd142397dfc72054170f022bb9da3277158f9b4f --- 
/dev/null +++ b/App_Function_Libraries/Gradio_UI/Semantic_Scholar_tab.py @@ -0,0 +1,184 @@ +# Sematnic_Scholar_tab.py +# Description: contains the code to create the Semantic Scholar tab in the Gradio UI. +# +# Imports +# +# External Libraries +import gradio as gr +# +# Internal Libraries +from App_Function_Libraries.Third_Party.Semantic_Scholar import search_and_display, FIELDS_OF_STUDY, PUBLICATION_TYPES + + +# +###################################################################################################################### +# Functions +def create_semantic_scholar_tab(): + """Create the Semantic Scholar tab for the Gradio UI""" + with gr.Tab("Semantic Scholar Search"): + with gr.Row(): + with gr.Column(scale=2): + gr.Markdown(""" + ## Semantic Scholar Paper Search + + This interface allows you to search for academic papers using the Semantic Scholar API with advanced filtering options: + + ### Search Options + - **Keywords**: Search across titles, abstracts, and other paper content + - **Year Range**: Filter papers by publication year (e.g., "2020-2023" or "2020") + - **Venue**: Filter by publication venue (journal or conference) + - **Minimum Citations**: Filter papers by minimum citation count + - **Fields of Study**: Filter papers by academic field + - **Publication Types**: Filter by type of publication + - **Open Access**: Option to show only papers with free PDF access + + ### Results Include + - Paper title + - Author list + - Publication year and venue + - Citation count + - Publication types + - Abstract + - Links to PDF (when available) and Semantic Scholar page + """) + with gr.Column(scale=2): + gr.Markdown(""" + ### Pagination + - 10 results per page + - Navigate through results using Previous/Next buttons + - Current page number and total results displayed + + ### Usage Tips + - Combine multiple filters for more specific results + - Use specific terms for more focused results + - Try different combinations of filters if you don't find what 
you're looking for + """) + with gr.Row(): + with gr.Column(scale=2): + search_input = gr.Textbox( + label="Search Query", + placeholder="Enter keywords to search for papers...", + lines=1 + ) + + # Advanced search options + with gr.Row(): + year_range = gr.Textbox( + label="Year Range", + placeholder="e.g., 2020-2023 or 2020", + lines=1 + ) + venue = gr.Textbox( + label="Venue", + placeholder="e.g., Nature, Science", + lines=1 + ) + min_citations = gr.Number( + label="Minimum Citations", + value=0, + minimum=0, + step=1 + ) + + with gr.Row(): + fields_of_study = gr.Dropdown( + choices=FIELDS_OF_STUDY, + label="Fields of Study", + multiselect=True, + value=[] + ) + publication_types = gr.Dropdown( + choices=PUBLICATION_TYPES, + label="Publication Types", + multiselect=True, + value=[] + ) + + open_access_only = gr.Checkbox( + label="Open Access Only", + value=False + ) + + with gr.Column(scale=1): + search_button = gr.Button("Search", variant="primary") + + # Pagination controls + with gr.Row(): + prev_button = gr.Button("← Previous") + current_page = gr.Number(value=0, label="Page", minimum=0, step=1) + max_page = gr.Number(value=0, label="Max Page", visible=False) + next_button = gr.Button("Next →") + + total_results = gr.Textbox( + label="Total Results", + value="0", + interactive=False + ) + + output_text = gr.Markdown( + label="Results", + value="Use the search options above to find papers." 
+ ) + + def update_page(direction, current, maximum): + new_page = current + direction + if new_page < 0: + return 0 + if new_page > maximum: + return maximum + return new_page + + # Handle search and pagination + def search_from_button(query, fields_of_study, publication_types, year_range, venue, min_citations, + open_access_only): + """Wrapper to always search from page 0 when search button is clicked""" + return search_and_display( + query=query, + page=0, # Force page 0 for new searches + fields_of_study=fields_of_study, + publication_types=publication_types, + year_range=year_range, + venue=venue, + min_citations=min_citations, + open_access_only=open_access_only + ) + normal_search = search_and_display + + search_button.click( + fn=search_from_button, + inputs=[ + search_input, fields_of_study, publication_types, + year_range, venue, min_citations, open_access_only + ], + outputs=[output_text, current_page, max_page, total_results] + ) + + prev_button.click( + fn=lambda curr, max_p: update_page(-1, curr, max_p), + inputs=[current_page, max_page], + outputs=current_page + ).then( + fn=normal_search, + inputs=[ + search_input, current_page, fields_of_study, publication_types, + year_range, venue, min_citations, open_access_only + ], + outputs=[output_text, current_page, max_page, total_results] + ) + + next_button.click( + fn=lambda curr, max_p: update_page(1, curr, max_p), + inputs=[current_page, max_page], + outputs=current_page + ).then( + fn=normal_search, + inputs=[ + search_input, current_page, fields_of_study, publication_types, + year_range, venue, min_citations, open_access_only + ], + outputs=[output_text, current_page, max_page, total_results] + ) + +# +# End of Semantic_Scholar_tab.py +###################################################################################################################### diff --git a/App_Function_Libraries/Gradio_UI/Video_transcription_tab.py b/App_Function_Libraries/Gradio_UI/Video_transcription_tab.py index 
d27f5b84dd3a1ac20add9a8aa35c773fb8b4a0f1..3db1706d92716e8dca425b94ffd59eaac99c24ab 100644 --- a/App_Function_Libraries/Gradio_UI/Video_transcription_tab.py +++ b/App_Function_Libraries/Gradio_UI/Video_transcription_tab.py @@ -6,22 +6,23 @@ import json import logging import os from datetime import datetime -from typing import Dict, Any - # # External Imports import gradio as gr import yt_dlp + +from App_Function_Libraries.Chunk_Lib import improved_chunking_process # # Local Imports -from App_Function_Libraries.DB.DB_Manager import load_preset_prompts, add_media_to_database, \ - check_media_and_whisper_model, check_existing_media, update_media_content_with_version +from App_Function_Libraries.DB.DB_Manager import add_media_to_database, \ + check_media_and_whisper_model, check_existing_media, update_media_content_with_version, list_prompts from App_Function_Libraries.Gradio_UI.Gradio_Shared import whisper_models, update_user_prompt from App_Function_Libraries.Gradio_UI.Gradio_Shared import error_handler from App_Function_Libraries.Summarization.Summarization_General_Lib import perform_transcription, perform_summarization, \ save_transcription_and_summary from App_Function_Libraries.Utils.Utils import convert_to_seconds, safe_read_file, format_transcription, \ - create_download_directory, generate_unique_identifier, extract_text_from_segments + create_download_directory, generate_unique_identifier, extract_text_from_segments, default_api_endpoint, \ + global_api_endpoints, format_api_name from App_Function_Libraries.Video_DL_Ingestion_Lib import parse_and_expand_urls, extract_metadata, download_video from App_Function_Libraries.Benchmarks_Evaluations.ms_g_eval import run_geval # Import metrics logging @@ -32,6 +33,16 @@ from App_Function_Libraries.Metrics.metrics_logger import log_counter, log_histo # Functions: def create_video_transcription_tab(): + try: + default_value = None + if default_api_endpoint: + if default_api_endpoint in global_api_endpoints: + 
default_value = format_api_name(default_api_endpoint) + else: + logging.warning(f"Default API endpoint '{default_api_endpoint}' not found in global_api_endpoints") + except Exception as e: + logging.error(f"Error setting default API endpoint: {str(e)}") + default_value = None with gr.TabItem("Video Transcription + Summarization", visible=True): gr.Markdown("# Transcribe & Summarize Videos from URLs") with gr.Row(): @@ -56,15 +67,20 @@ def create_video_transcription_tab(): preset_prompt_checkbox = gr.Checkbox(label="Use a pre-set Prompt", value=False, visible=True) + + # Initialize state variables for pagination + current_page_state = gr.State(value=1) + total_pages_state = gr.State(value=1) + with gr.Row(): + # Add pagination controls preset_prompt = gr.Dropdown(label="Select Preset Prompt", - choices=load_preset_prompts(), + choices=[], visible=False) with gr.Row(): - custom_prompt_input = gr.Textbox(label="Custom Prompt", - placeholder="Enter custom prompt here", - lines=3, - visible=False) + prev_page_button = gr.Button("Previous Page", visible=False) + page_display = gr.Markdown("Page 1 of X", visible=False) + next_page_button = gr.Button("Next Page", visible=False) with gr.Row(): system_prompt_input = gr.Textbox(label="System Prompt", value="""You are a bulleted notes specialist. [INST]```When creating comprehensive bulleted notes, you should follow these guidelines: Use multiple headings based on the referenced topics, not categories like quotes or terms. Headings should be surrounded by bold formatting and not be listed as bullet points themselves. Leave no space between headings and their corresponding list items underneath. Important terms within the content should be emphasized by setting them in bold font. Any text that ends with a colon should also be bolded. Before submitting your response, review the instructions, and make any corrections necessary to adhered to the specified format. 
Do not reference these instructions within the notes.``` \nBased on the content between backticks create comprehensive bulleted notes.[/INST] @@ -87,22 +103,75 @@ def create_video_transcription_tab(): lines=3, visible=False, interactive=True) + with gr.Row(): + custom_prompt_input = gr.Textbox(label="Custom Prompt", + placeholder="Enter custom prompt here", + lines=3, + visible=False) + custom_prompt_checkbox.change( - fn=lambda x: (gr.update(visible=x), gr.update(visible=x)), + fn=lambda x: (gr.update(visible=x, interactive=x), gr.update(visible=x, interactive=x)), inputs=[custom_prompt_checkbox], outputs=[custom_prompt_input, system_prompt_input] ) + + def on_preset_prompt_checkbox_change(is_checked): + if is_checked: + prompts, total_pages, current_page = list_prompts(page=1, per_page=20) + page_display_text = f"Page {current_page} of {total_pages}" + return ( + gr.update(visible=True, interactive=True, choices=prompts), # preset_prompt + gr.update(visible=True), # prev_page_button + gr.update(visible=True), # next_page_button + gr.update(value=page_display_text, visible=True), # page_display + current_page, # current_page_state + total_pages # total_pages_state + ) + else: + return ( + gr.update(visible=False, interactive=False), # preset_prompt + gr.update(visible=False), # prev_page_button + gr.update(visible=False), # next_page_button + gr.update(visible=False), # page_display + 1, # current_page_state + 1 # total_pages_state + ) + preset_prompt_checkbox.change( - fn=lambda x: gr.update(visible=x), + fn=on_preset_prompt_checkbox_change, inputs=[preset_prompt_checkbox], - outputs=[preset_prompt] + outputs=[preset_prompt, prev_page_button, next_page_button, page_display, current_page_state, total_pages_state] + ) + + def on_prev_page_click(current_page, total_pages): + new_page = max(current_page - 1, 1) + prompts, total_pages, current_page = list_prompts(page=new_page, per_page=20) + page_display_text = f"Page {current_page} of {total_pages}" + return 
gr.update(choices=prompts), gr.update(value=page_display_text), current_page + + prev_page_button.click( + fn=on_prev_page_click, + inputs=[current_page_state, total_pages_state], + outputs=[preset_prompt, page_display, current_page_state] + ) + + def on_next_page_click(current_page, total_pages): + new_page = min(current_page + 1, total_pages) + prompts, total_pages, current_page = list_prompts(page=new_page, per_page=20) + page_display_text = f"Page {current_page} of {total_pages}" + return gr.update(choices=prompts), gr.update(value=page_display_text), current_page + + next_page_button.click( + fn=on_next_page_click, + inputs=[current_page_state, total_pages_state], + outputs=[preset_prompt, page_display, current_page_state] ) def update_prompts(preset_name): prompts = update_user_prompt(preset_name) return ( - gr.update(value=prompts["user_prompt"], visible=True), - gr.update(value=prompts["system_prompt"], visible=True) + gr.update(value=prompts["user_prompt"], visible=True, interactive=True), + gr.update(value=prompts["system_prompt"], visible=True, interactive=True) ) preset_prompt.change( @@ -111,11 +180,12 @@ def create_video_transcription_tab(): outputs=[custom_prompt_input, system_prompt_input] ) + # Refactored API selection dropdown api_name_input = gr.Dropdown( - choices=[None, "Local-LLM", "OpenAI", "Anthropic", "Cohere", "Groq", "DeepSeek", "Mistral", - "OpenRouter", - "Llama.cpp", "Kobold", "Ooba", "Tabbyapi", "VLLM", "ollama", "HuggingFace", "Custom-OpenAI-API"], - value=None, label="API Name (Mandatory)") + choices=["None"] + [format_api_name(api) for api in global_api_endpoints], + value=default_value, + label="API for Summarization/Analysis (Optional)" + ) api_key_input = gr.Textbox(label="API Key (Optional - Set in Config.txt)", placeholder="Enter your API key here", type="password") keywords_input = gr.Textbox(label="Keywords", placeholder="Enter keywords here (comma-separated)", @@ -198,8 +268,7 @@ def create_video_transcription_tab(): 
progress: gr.Progress = gr.Progress()) -> tuple: try: # Start overall processing timer - proc_start_time = datetime.utcnow() - # FIXME - summarize_recursively is not being used... + proc_start_time = datetime.now() logging.info("Entering process_videos_with_error_handling") logging.info(f"Received inputs: {inputs}") @@ -251,8 +320,7 @@ def create_video_transcription_tab(): all_summaries = "" # Start timing - # FIXME - utcnow() is deprecated and scheduled for removal in a future version. Use timezone-aware objects to represent datetimes in UTC: datetime.datetime.now(datetime.UTC). - start_proc = datetime.utcnow() + start_proc = datetime.now() for i in range(0, len(all_inputs), batch_size): batch = all_inputs[i:i + batch_size] @@ -260,7 +328,7 @@ def create_video_transcription_tab(): for input_item in batch: # Start individual video processing timer - video_start_time = datetime.utcnow() + video_start_time = datetime.now() try: start_seconds = convert_to_seconds(start_time) end_seconds = convert_to_seconds(end_time) if end_time else None @@ -313,7 +381,7 @@ def create_video_transcription_tab(): input_item, 2, whisper_model, custom_prompt, start_seconds, api_name, api_key, - vad_use, False, False, False, 0.01, None, keywords, None, diarize, + vad_use, False, False, summarize_recursively, 0.01, None, keywords, None, diarize, end_time=end_seconds, include_timestamps=timestamp_option, metadata=video_metadata, @@ -365,7 +433,7 @@ def create_video_transcription_tab(): ) # Calculate processing time - video_end_time = datetime.utcnow() + video_end_time = datetime.now() processing_time = (video_end_time - video_start_time).total_seconds() log_histogram( metric_name="video_processing_time_seconds", @@ -473,7 +541,7 @@ def create_video_transcription_tab(): total_inputs = len(all_inputs) # End overall processing timer - proc_end_time = datetime.utcnow() + proc_end_time = datetime.now() total_processing_time = (proc_end_time - proc_start_time).total_seconds() log_histogram( 
metric_name="total_processing_time_seconds", @@ -702,8 +770,9 @@ def create_video_transcription_tab(): # Perform transcription logging.info("process_url_with_metadata: Starting transcription...") + logging.info(f"process_url_with_metadata: overwrite existing?: {overwrite_existing}") audio_file_path, segments = perform_transcription(video_file_path, offset, whisper_model, - vad_filter, diarize) + vad_filter, diarize, overwrite_existing) if audio_file_path is None or segments is None: logging.error("process_url_with_metadata: Transcription failed or segments not available.") @@ -771,7 +840,54 @@ def create_video_transcription_tab(): # API key resolution handled at base of function if none provided api_key = api_key if api_key else None logging.info(f"process_url_with_metadata: Starting summarization with {api_name}...") - summary_text = perform_summarization(api_name, full_text_with_metadata, custom_prompt, api_key) + + # Perform Chunking if enabled + # FIXME - Setup a proper prompt for Recursive Summarization + if use_chunking: + logging.info("process_url_with_metadata: Chunking enabled. Starting chunking...") + chunked_texts = improved_chunking_process(full_text_with_metadata, chunk_options) + + if chunked_texts is None: + logging.warning("Chunking failed, falling back to full text summarization") + summary_text = perform_summarization(api_name, full_text_with_metadata, custom_prompt, + api_key) + else: + logging.debug( + f"process_url_with_metadata: Chunking completed. 
Processing {len(chunked_texts)} chunks...") + summaries = [] + + if rolling_summarization: + # Perform recursive summarization on each chunk + for chunk in chunked_texts: + chunk_summary = perform_summarization(api_name, chunk['text'], custom_prompt, + api_key) + if chunk_summary: + summaries.append( + f"Chunk {chunk['metadata']['chunk_index']}/{chunk['metadata']['total_chunks']}: {chunk_summary}") + summary_text = "\n\n".join(summaries) + else: + logging.error("All chunk summarizations failed") + summary_text = None + + for chunk in chunked_texts: + # Perform Non-recursive summarization on each chunk + chunk_summary = perform_summarization(api_name, chunk['text'], custom_prompt, + api_key) + if chunk_summary: + summaries.append( + f"Chunk {chunk['metadata']['chunk_index']}/{chunk['metadata']['total_chunks']}: {chunk_summary}") + + if summaries: + summary_text = "\n\n".join(summaries) + logging.info(f"Successfully summarized {len(summaries)} chunks") + else: + logging.error("All chunk summarizations failed") + summary_text = None + else: + # Regular summarization without chunking + summary_text = perform_summarization(api_name, full_text_with_metadata, custom_prompt, + api_key) if api_name else None + if summary_text is None: logging.error("Summarization failed.") return None, None, None, None, None, None @@ -859,3 +975,7 @@ def create_video_transcription_tab(): ], outputs=[progress_output, error_output, results_output, download_transcription, download_summary, confabulation_output] ) + +# +# End of Video_transcription_tab.py +####################################################################################################################### diff --git a/App_Function_Libraries/Gradio_UI/View_DB_Items_tab.py b/App_Function_Libraries/Gradio_UI/View_DB_Items_tab.py index 16a3fa9b34ec7c6abba38cd2b9c7606943231f98..5bbceb0571b2d3f78e39e7ec3fc383baa2608a5b 100644 --- a/App_Function_Libraries/Gradio_UI/View_DB_Items_tab.py +++ 
b/App_Function_Libraries/Gradio_UI/View_DB_Items_tab.py @@ -3,131 +3,26 @@ # # Imports import html +import logging + # # External Imports import gradio as gr # # Local Imports from App_Function_Libraries.DB.DB_Manager import view_database, get_all_document_versions, \ - fetch_paginated_data, fetch_item_details, get_latest_transcription, list_prompts, fetch_prompt_details, \ - load_preset_prompts -from App_Function_Libraries.DB.SQLite_DB import get_document_version + fetch_paginated_data, fetch_item_details, get_latest_transcription, list_prompts, fetch_prompt_details +from App_Function_Libraries.DB.RAG_QA_Chat_DB import get_keywords_for_note, search_conversations_by_keywords, \ + get_notes_by_keywords, get_keywords_for_conversation, get_db_connection, get_all_conversations, load_chat_history, \ + get_notes +from App_Function_Libraries.DB.SQLite_DB import get_document_version, fetch_items_by_keyword, fetch_all_keywords + + # #################################################################################################### # # Functions -def create_prompt_view_tab(): - with gr.TabItem("View Prompt Database", visible=True): - gr.Markdown("# View Prompt Database Entries") - with gr.Row(): - with gr.Column(): - entries_per_page = gr.Dropdown(choices=[10, 20, 50, 100], label="Entries per Page", value=10) - page_number = gr.Number(value=1, label="Page Number", precision=0) - view_button = gr.Button("View Page") - next_page_button = gr.Button("Next Page") - previous_page_button = gr.Button("Previous Page") - pagination_info = gr.Textbox(label="Pagination Info", interactive=False) - prompt_selector = gr.Dropdown(label="Select Prompt to View", choices=[]) - with gr.Column(): - results_table = gr.HTML() - selected_prompt_display = gr.HTML() - - def view_database(page, entries_per_page): - try: - prompts, total_pages, current_page = list_prompts(page, entries_per_page) - - table_html = "" - table_html += "" - prompt_choices = [] - for prompt_name in prompts: - details = 
fetch_prompt_details(prompt_name) - if details: - title, _, _, _, _, _ = details - author = "Unknown" # Assuming author is not stored in the current schema - table_html += f"" - prompt_choices.append((title, title)) # Using title as both label and value - table_html += "
TitleAuthor
{html.escape(title)}{html.escape(author)}
" - - total_prompts = len(load_preset_prompts()) # This might be inefficient for large datasets - pagination = f"Page {current_page} of {total_pages} (Total prompts: {total_prompts})" - - return table_html, pagination, total_pages, prompt_choices - except Exception as e: - return f"

Error fetching prompts: {e}

", "Error", 0, [] - - def update_page(page, entries_per_page): - results, pagination, total_pages, prompt_choices = view_database(page, entries_per_page) - next_disabled = page >= total_pages - prev_disabled = page <= 1 - return results, pagination, page, gr.update(interactive=not next_disabled), gr.update( - interactive=not prev_disabled), gr.update(choices=prompt_choices) - - def go_to_next_page(current_page, entries_per_page): - next_page = current_page + 1 - return update_page(next_page, entries_per_page) - - def go_to_previous_page(current_page, entries_per_page): - previous_page = max(1, current_page - 1) - return update_page(previous_page, entries_per_page) - - def display_selected_prompt(prompt_name): - details = fetch_prompt_details(prompt_name) - if details: - title, author, description, system_prompt, user_prompt, keywords = details - # Handle None values by converting them to empty strings - description = description or "" - system_prompt = system_prompt or "" - user_prompt = user_prompt or "" - author = author or "Unknown" - keywords = keywords or "" - - html_content = f""" -
-

{html.escape(title)}

by {html.escape(author)}

-

Description: {html.escape(description)}

-
- System Prompt: -
{html.escape(system_prompt)}
-
-
- User Prompt: -
{html.escape(user_prompt)}
-
-

Keywords: {html.escape(keywords)}

-
- """ - return html_content - else: - return "

Prompt not found.

" - - view_button.click( - fn=update_page, - inputs=[page_number, entries_per_page], - outputs=[results_table, pagination_info, page_number, next_page_button, previous_page_button, - prompt_selector] - ) - - next_page_button.click( - fn=go_to_next_page, - inputs=[page_number, entries_per_page], - outputs=[results_table, pagination_info, page_number, next_page_button, previous_page_button, - prompt_selector] - ) - - previous_page_button.click( - fn=go_to_previous_page, - inputs=[page_number, entries_per_page], - outputs=[results_table, pagination_info, page_number, next_page_button, previous_page_button, - prompt_selector] - ) - - prompt_selector.change( - fn=display_selected_prompt, - inputs=[prompt_selector], - outputs=[selected_prompt_display] - ) - def format_as_html(content, title): escaped_content = html.escape(content) formatted_content = escaped_content.replace('\n', '
') @@ -149,9 +44,9 @@ def extract_prompt_and_summary(content: str): return prompt, summary -def create_view_all_with_versions_tab(): - with gr.TabItem("View All Items", visible=True): - gr.Markdown("# View All Database Entries with Version Selection") +def create_view_all_mediadb_with_versions_tab(): + with gr.TabItem("View All MediaDB Items", visible=True): + gr.Markdown("# View All Media Database Entries with Version Selection") with gr.Row(): with gr.Column(scale=1): entries_per_page = gr.Dropdown(choices=[10, 20, 50, 100], label="Entries per Page", value=10) @@ -280,9 +175,143 @@ def create_view_all_with_versions_tab(): ) -def create_viewing_tab(): - with gr.TabItem("View Database Entries", visible=True): - gr.Markdown("# View Database Entries") +def create_mediadb_keyword_search_tab(): + with gr.TabItem("Search MediaDB by Keyword", visible=True): + gr.Markdown("# List Media Database Items by Keyword") + + with gr.Row(): + with gr.Column(scale=1): + # Keyword selection dropdown - initialize with empty list, will be populated on load + keyword_dropdown = gr.Dropdown( + label="Select Keyword", + choices=fetch_all_keywords(), # Initialize with keywords on creation + value=None + ) + entries_per_page = gr.Dropdown( + choices=[10, 20, 50, 100], + label="Entries per Page", + value=10 + ) + page_number = gr.Number( + value=1, + label="Page Number", + precision=0 + ) + + # Navigation buttons + refresh_keywords_button = gr.Button("Refresh Keywords") + view_button = gr.Button("View Results") + next_page_button = gr.Button("Next Page") + previous_page_button = gr.Button("Previous Page") + + # Pagination information + pagination_info = gr.Textbox( + label="Pagination Info", + interactive=False + ) + + with gr.Column(scale=2): + # Results area + results_table = gr.HTML( + label="Search Results" + ) + item_details = gr.HTML( + label="Item Details", + visible=True + ) + + def update_keyword_choices(): + try: + keywords = fetch_all_keywords() + return 
gr.update(choices=keywords) + except Exception as e: + return gr.update(choices=[], value=None) + + def search_items(keyword, page, entries_per_page): + try: + # Calculate offset for pagination + offset = (page - 1) * entries_per_page + + # Fetch items for the selected keyword + items = fetch_items_by_keyword(keyword) + total_items = len(items) + total_pages = (total_items + entries_per_page - 1) // entries_per_page + + # Paginate results + paginated_items = items[offset:offset + entries_per_page] + + # Generate HTML table for results + table_html = "" + table_html += "" + table_html += "" + + for item_id, title, url in paginated_items: + table_html += f""" + + + + + """ + table_html += "
TitleURL
{html.escape(title)}{html.escape(url)}
" + + # Update pagination info + pagination = f"Page {page} of {total_pages} (Total items: {total_items})" + + # Determine button states + next_disabled = page >= total_pages + prev_disabled = page <= 1 + + return ( + table_html, + pagination, + gr.update(interactive=not next_disabled), + gr.update(interactive=not prev_disabled) + ) + except Exception as e: + return ( + f"

Error: {str(e)}

", + "Error in pagination", + gr.update(interactive=False), + gr.update(interactive=False) + ) + + def go_to_next_page(keyword, current_page, entries_per_page): + next_page = current_page + 1 + return search_items(keyword, next_page, entries_per_page) + (next_page,) + + def go_to_previous_page(keyword, current_page, entries_per_page): + previous_page = max(1, current_page - 1) + return search_items(keyword, previous_page, entries_per_page) + (previous_page,) + + # Event handlers + refresh_keywords_button.click( + fn=update_keyword_choices, + inputs=[], + outputs=[keyword_dropdown] + ) + + view_button.click( + fn=search_items, + inputs=[keyword_dropdown, page_number, entries_per_page], + outputs=[results_table, pagination_info, next_page_button, previous_page_button] + ) + + next_page_button.click( + fn=go_to_next_page, + inputs=[keyword_dropdown, page_number, entries_per_page], + outputs=[results_table, pagination_info, next_page_button, previous_page_button, page_number] + ) + + previous_page_button.click( + fn=go_to_previous_page, + inputs=[keyword_dropdown, page_number, entries_per_page], + outputs=[results_table, pagination_info, next_page_button, previous_page_button, page_number] + ) + + +def create_viewing_mediadb_tab(): + with gr.TabItem("View Media Database Entries", visible=True): + gr.Markdown("# View Media Database Entries") with gr.Row(): with gr.Column(): entries_per_page = gr.Dropdown(choices=[10, 20, 50, 100], label="Entries per Page", value=10) @@ -327,5 +356,461 @@ def create_viewing_tab(): outputs=[results_display, pagination_info, page_number, next_page_button, previous_page_button] ) +##################################################################### +# +# RAG DB Viewing Functions: + +def create_viewing_ragdb_tab(): + with gr.TabItem("View RAG Database Entries", visible=True): + gr.Markdown("# View RAG Database Entries") + with gr.Row(): + with gr.Column(): + entries_per_page = gr.Dropdown(choices=[10, 20, 50, 100], label="Entries per Page", 
value=10) + page_number = gr.Number(value=1, label="Page Number", precision=0) + view_button = gr.Button("View Page") + next_page_button = gr.Button("Next Page") + previous_page_button = gr.Button("Previous Page") + pagination_info = gr.Textbox(label="Pagination Info", interactive=False) + with gr.Column(): + results_display = gr.HTML() + + def format_conversations_table(conversations): + table_html = "" + table_html += """ + + + + + + + """ + + for conversation in conversations: + conv_id = conversation['conversation_id'] + title = conversation['title'] + rating = conversation.get('rating', '') # Use get() to handle cases where rating might not exist + + keywords = get_keywords_for_conversation(conv_id) + notes = get_notes(conv_id) + + table_html += f""" + + + + + + + """ + table_html += "
TitleKeywordsNotesRating
{html.escape(str(title))}{html.escape(', '.join(keywords))}{len(notes)} note(s){html.escape(str(rating))}
" + return table_html + + def update_page(page, entries_per_page): + try: + conversations, total_pages, total_count = get_all_conversations(page, entries_per_page) + results_html = format_conversations_table(conversations) + pagination = f"Page {page} of {total_pages} (Total conversations: {total_count})" + + next_disabled = page >= total_pages + prev_disabled = page <= 1 + + return ( + results_html, + pagination, + page, + gr.update(interactive=not next_disabled), + gr.update(interactive=not prev_disabled) + ) + except Exception as e: + return ( + f"

Error: {str(e)}

", + "Error in pagination", + page, + gr.update(interactive=False), + gr.update(interactive=False) + ) + + def go_to_next_page(current_page, entries_per_page): + return update_page(current_page + 1, entries_per_page) + + def go_to_previous_page(current_page, entries_per_page): + return update_page(max(1, current_page - 1), entries_per_page) + + view_button.click( + fn=update_page, + inputs=[page_number, entries_per_page], + outputs=[results_display, pagination_info, page_number, next_page_button, previous_page_button] + ) + + next_page_button.click( + fn=go_to_next_page, + inputs=[page_number, entries_per_page], + outputs=[results_display, pagination_info, page_number, next_page_button, previous_page_button] + ) + + previous_page_button.click( + fn=go_to_previous_page, + inputs=[page_number, entries_per_page], + outputs=[results_display, pagination_info, page_number, next_page_button, previous_page_button] + ) + + +def create_view_all_rag_notes_tab(): + with gr.TabItem("View All RAG notes/Conversation Items", visible=True): + gr.Markdown("# View All RAG Notes/Conversation Entries") + with gr.Row(): + with gr.Column(scale=1): + entries_per_page = gr.Dropdown(choices=[10, 20, 50, 100], label="Entries per Page", value=10) + page_number = gr.Number(value=1, label="Page Number", precision=0) + view_button = gr.Button("View Page") + next_page_button = gr.Button("Next Page") + previous_page_button = gr.Button("Previous Page") + with gr.Column(scale=2): + items_output = gr.Dropdown(label="Select Conversation to View Details", choices=[]) + conversation_title = gr.Textbox(label="Conversation Title", visible=True) + with gr.Row(): + with gr.Column(scale=1): + pagination_info = gr.Textbox(label="Pagination Info", interactive=False) + with gr.Column(scale=2): + keywords_output = gr.Textbox(label="Keywords", visible=True) + chat_history_output = gr.HTML(label="Chat History", visible=True) + notes_output = gr.HTML(label="Associated Notes", visible=True) + + item_mapping = 
gr.State({}) + + def update_page(page, entries_per_page): + try: + conversations, total_pages, total_count = get_all_conversations(page, entries_per_page) + pagination = f"Page {page} of {total_pages} (Total conversations: {total_count})" + + # Handle the dictionary structure correctly + choices = [f"{conv['title']} (ID: {conv['conversation_id']})" for conv in conversations] + new_item_mapping = { + f"{conv['title']} (ID: {conv['conversation_id']})": conv['conversation_id'] + for conv in conversations + } + + next_disabled = page >= total_pages + prev_disabled = page <= 1 + + return ( + gr.update(choices=choices, value=None), + pagination, + page, + gr.update(interactive=not next_disabled), + gr.update(interactive=not prev_disabled), + "", # conversation_title + "", # keywords_output + "", # chat_history_output + "", # notes_output + new_item_mapping + ) + except Exception as e: + logging.error(f"Error in update_page: {str(e)}", exc_info=True) + return ( + gr.update(choices=[], value=None), + f"Error: {str(e)}", + page, + gr.update(interactive=False), + gr.update(interactive=False), + "", "", "", "", + {} + ) + + def format_as_html(content, title): + if content is None: + content = "No content available." + escaped_content = html.escape(str(content)) + formatted_content = escaped_content.replace('\n', '
') + return f""" +
+

{title}

+
+ {formatted_content} +
+
+ """ + + def format_chat_history(messages): + html_content = "
" + for role, content in messages: + role_class = "assistant" if role.lower() == "assistant" else "user" + html_content += f""" +
+ {html.escape(role)}:
+ {html.escape(content)} +
+ """ + html_content += "
" + return html_content + + def display_conversation_details(selected_item, item_mapping): + if selected_item and item_mapping and selected_item in item_mapping: + conv_id = item_mapping[selected_item] + + # Get keywords + keywords = get_keywords_for_conversation(conv_id) + keywords_text = ", ".join(keywords) if keywords else "No keywords" + + # Get chat history + chat_messages, _, _ = load_chat_history(conv_id) + chat_html = format_chat_history(chat_messages) + + # Get associated notes + notes = get_notes(conv_id) + notes_html = "" + for note in notes: + notes_html += format_as_html(note, "Note") + if not notes: + notes_html = "

No notes associated with this conversation.

" + + return ( + selected_item.split(" (ID:")[0], # Conversation title + keywords_text, + chat_html, + notes_html + ) + return "", "", "", "" + + view_button.click( + fn=update_page, + inputs=[page_number, entries_per_page], + outputs=[ + items_output, + pagination_info, + page_number, + next_page_button, + previous_page_button, + conversation_title, + keywords_output, + chat_history_output, + notes_output, + item_mapping + ] + ) + + next_page_button.click( + fn=lambda page, entries: update_page(page + 1, entries), + inputs=[page_number, entries_per_page], + outputs=[items_output, pagination_info, page_number, next_page_button, previous_page_button, + conversation_title, keywords_output, chat_history_output, notes_output, item_mapping] + ) + + previous_page_button.click( + fn=lambda page, entries: update_page(max(1, page - 1), entries), + inputs=[page_number, entries_per_page], + outputs=[items_output, pagination_info, page_number, next_page_button, previous_page_button, + conversation_title, keywords_output, chat_history_output, notes_output, item_mapping] + ) + + items_output.change( + fn=display_conversation_details, + inputs=[items_output, item_mapping], + outputs=[conversation_title, keywords_output, chat_history_output, notes_output] + ) + + +def create_ragdb_keyword_items_tab(): + with gr.TabItem("View RAG Notes/Conversations by Keyword", visible=True): + gr.Markdown("# View RAG Notes and Conversations by Keyword") + + with gr.Row(): + with gr.Column(scale=1): + # Keyword selection + keyword_dropdown = gr.Dropdown( + label="Select Keyword", + choices=[], + value=None, + multiselect=True + ) + entries_per_page = gr.Dropdown( + choices=[10, 20, 50, 100], + label="Entries per Page", + value=10 + ) + page_number = gr.Number( + value=1, + label="Page Number", + precision=0 + ) + + # Navigation buttons + refresh_keywords_button = gr.Button("Refresh Keywords") + view_button = gr.Button("View Items") + next_page_button = gr.Button("Next Page") + previous_page_button 
= gr.Button("Previous Page") + pagination_info = gr.Textbox( + label="Pagination Info", + interactive=False + ) + + with gr.Column(scale=2): + # Results tabs for conversations and notes + with gr.Tabs(): + with gr.Tab("Notes"): + notes_results = gr.HTML() + with gr.Tab("Conversations"): + conversation_results = gr.HTML() + + def update_keyword_choices(): + """Fetch all available keywords for the dropdown.""" + try: + query = "SELECT keyword FROM rag_qa_keywords ORDER BY keyword" + with get_db_connection() as conn: + cursor = conn.cursor() + cursor.execute(query) + keywords = [row[0] for row in cursor.fetchall()] + return gr.update(choices=keywords) + except Exception as e: + return gr.update(choices=[], value=None) + + def format_conversations_html(conversations_data): + """Format conversations data as HTML.""" + if not conversations_data: + return "

No conversations found for selected keywords.

" + + html_content = "
" + for conv_id, title in conversations_data: + html_content += f""" +
+

{html.escape(title)}

+

Conversation ID: {html.escape(conv_id)}

+

Keywords: {', '.join(html.escape(k) for k in get_keywords_for_conversation(conv_id))}

+
+ """ + html_content += "
" + return html_content + + def format_notes_html(notes_data): + """Format notes data as HTML.""" + if not notes_data: + return "

No notes found for selected keywords.

" + + html_content = "
" + for note_id, title, content, timestamp in notes_data: + keywords = get_keywords_for_note(note_id) + html_content += f""" +
+

{html.escape(title)}

+

Created: {timestamp}

+

Keywords: {', '.join(html.escape(k) for k in keywords)}

+
+ {html.escape(content)} +
+
+ """ + html_content += "
" + return html_content + + def view_items(keywords, page, entries_per_page): + if not keywords or (isinstance(keywords, list) and len(keywords) == 0): + return ( + "

Please select at least one keyword.

", + "

Please select at least one keyword.

", + "No results", + gr.update(interactive=False), + gr.update(interactive=False) + ) + + try: + # Ensure keywords is a list + keywords_list = keywords if isinstance(keywords, list) else [keywords] + + # Get conversations for selected keywords + conversations, conv_total_pages, conv_count = search_conversations_by_keywords( + keywords_list, page, entries_per_page + ) + + # Get notes for selected keywords + notes, notes_total_pages, notes_count = get_notes_by_keywords( + keywords_list, page, entries_per_page + ) + + # Format results as HTML + conv_html = format_conversations_html(conversations) + notes_html = format_notes_html(notes) + + # Create pagination info + pagination = f"Page {page} of {max(conv_total_pages, notes_total_pages)} " + pagination += f"(Conversations: {conv_count}, Notes: {notes_count})" + + # Determine button states + max_pages = max(conv_total_pages, notes_total_pages) + next_disabled = page >= max_pages + prev_disabled = page <= 1 + + return ( + conv_html, + notes_html, + pagination, + gr.update(interactive=not next_disabled), + gr.update(interactive=not prev_disabled) + ) + except Exception as e: + logging.error(f"Error in view_items: {str(e)}") + return ( + f"

Error: {str(e)}

", + f"

Error: {str(e)}

", + "Error in retrieval", + gr.update(interactive=False), + gr.update(interactive=False) + ) + + def go_to_next_page(keywords, current_page, entries_per_page): + return view_items(keywords, current_page + 1, entries_per_page) + + def go_to_previous_page(keywords, current_page, entries_per_page): + return view_items(keywords, max(1, current_page - 1), entries_per_page) + + # Event handlers + refresh_keywords_button.click( + fn=update_keyword_choices, + inputs=[], + outputs=[keyword_dropdown] + ) + + view_button.click( + fn=view_items, + inputs=[keyword_dropdown, page_number, entries_per_page], + outputs=[ + conversation_results, + notes_results, + pagination_info, + next_page_button, + previous_page_button + ] + ) + + next_page_button.click( + fn=go_to_next_page, + inputs=[keyword_dropdown, page_number, entries_per_page], + outputs=[ + conversation_results, + notes_results, + pagination_info, + next_page_button, + previous_page_button + ] + ) + + previous_page_button.click( + fn=go_to_previous_page, + inputs=[keyword_dropdown, page_number, entries_per_page], + outputs=[ + conversation_results, + notes_results, + pagination_info, + next_page_button, + previous_page_button + ] + ) + + # Initialize keyword dropdown on page load + keyword_dropdown.value = update_keyword_choices() + +# +# End of RAG DB Viewing tabs +################################################################ + # -#################################################################################################### \ No newline at end of file +####################################################################################################################### diff --git a/App_Function_Libraries/Gradio_UI/Website_scraping_tab.py b/App_Function_Libraries/Gradio_UI/Website_scraping_tab.py index 6640943cd72f34f86379411603cc71dce6825062..ecffc238618bb51a4b85af7f653b5f40ca99a698 100644 --- a/App_Function_Libraries/Gradio_UI/Website_scraping_tab.py +++ 
b/App_Function_Libraries/Gradio_UI/Website_scraping_tab.py @@ -1,554 +1,754 @@ -# Website_scraping_tab.py -# Gradio UI for scraping websites -# -# Imports -import asyncio -import json -import logging -import os -import random -from concurrent.futures import ThreadPoolExecutor -from typing import Optional, List, Dict, Any -from urllib.parse import urlparse, urljoin - -# -# External Imports -import gradio as gr -from playwright.async_api import TimeoutError, async_playwright -from playwright.sync_api import sync_playwright - -# -# Local Imports -from App_Function_Libraries.Web_Scraping.Article_Extractor_Lib import scrape_from_sitemap, scrape_by_url_level, scrape_article -from App_Function_Libraries.Web_Scraping.Article_Summarization_Lib import scrape_and_summarize_multiple -from App_Function_Libraries.DB.DB_Manager import load_preset_prompts -from App_Function_Libraries.Gradio_UI.Chat_ui import update_user_prompt -from App_Function_Libraries.Summarization.Summarization_General_Lib import summarize - - -# -######################################################################################################################## -# -# Functions: - -def get_url_depth(url: str) -> int: - return len(urlparse(url).path.strip('/').split('/')) - - -def sync_recursive_scrape(url_input, max_pages, max_depth, progress_callback, delay=1.0): - def run_async_scrape(): - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - return loop.run_until_complete( - recursive_scrape(url_input, max_pages, max_depth, progress_callback, delay) - ) - - with ThreadPoolExecutor() as executor: - future = executor.submit(run_async_scrape) - return future.result() - - -async def recursive_scrape( - base_url: str, - max_pages: int, - max_depth: int, - progress_callback: callable, - delay: float = 1.0, - resume_file: str = 'scrape_progress.json', - user_agent: str = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/58.0.3029.110 Safari/537.3" -) -> 
List[Dict]: - async def save_progress(): - temp_file = resume_file + ".tmp" - with open(temp_file, 'w') as f: - json.dump({ - 'visited': list(visited), - 'to_visit': to_visit, - 'scraped_articles': scraped_articles, - 'pages_scraped': pages_scraped - }, f) - os.replace(temp_file, resume_file) # Atomic replace - - def is_valid_url(url: str) -> bool: - return url.startswith("http") and len(url) > 0 - - # Load progress if resume file exists - if os.path.exists(resume_file): - with open(resume_file, 'r') as f: - progress_data = json.load(f) - visited = set(progress_data['visited']) - to_visit = progress_data['to_visit'] - scraped_articles = progress_data['scraped_articles'] - pages_scraped = progress_data['pages_scraped'] - else: - visited = set() - to_visit = [(base_url, 0)] # (url, depth) - scraped_articles = [] - pages_scraped = 0 - - try: - async with async_playwright() as p: - browser = await p.chromium.launch(headless=True) - context = await browser.new_context(user_agent=user_agent) - - try: - while to_visit and pages_scraped < max_pages: - current_url, current_depth = to_visit.pop(0) - - if current_url in visited or current_depth > max_depth: - continue - - visited.add(current_url) - - # Update progress - progress_callback(f"Scraping page {pages_scraped + 1}/{max_pages}: {current_url}") - - try: - await asyncio.sleep(random.uniform(delay * 0.8, delay * 1.2)) - - # This function should be implemented to handle asynchronous scraping - article_data = await scrape_article_async(context, current_url) - - if article_data and article_data['extraction_successful']: - scraped_articles.append(article_data) - pages_scraped += 1 - - # If we haven't reached max depth, add child links to to_visit - if current_depth < max_depth: - page = await context.new_page() - await page.goto(current_url) - await page.wait_for_load_state("networkidle") - - links = await page.eval_on_selector_all('a[href]', - "(elements) => elements.map(el => el.href)") - for link in links: - child_url = 
urljoin(base_url, link) - if is_valid_url(child_url) and child_url.startswith( - base_url) and child_url not in visited and should_scrape_url(child_url): - to_visit.append((child_url, current_depth + 1)) - - await page.close() - - except Exception as e: - logging.error(f"Error scraping {current_url}: {str(e)}") - - # Save progress periodically (e.g., every 10 pages) - if pages_scraped % 10 == 0: - await save_progress() - - finally: - await browser.close() - - finally: - # These statements are now guaranteed to be reached after the scraping is done - await save_progress() - - # Remove the progress file when scraping is completed successfully - if os.path.exists(resume_file): - os.remove(resume_file) - - # Final progress update - progress_callback(f"Scraping completed. Total pages scraped: {pages_scraped}") - - return scraped_articles - - -async def scrape_article_async(context, url: str) -> Dict[str, Any]: - page = await context.new_page() - try: - await page.goto(url) - await page.wait_for_load_state("networkidle") - - title = await page.title() - content = await page.content() - - return { - 'url': url, - 'title': title, - 'content': content, - 'extraction_successful': True - } - except Exception as e: - logging.error(f"Error scraping article {url}: {str(e)}") - return { - 'url': url, - 'extraction_successful': False, - 'error': str(e) - } - finally: - await page.close() - - -def scrape_article_sync(url: str) -> Dict[str, Any]: - with sync_playwright() as p: - browser = p.chromium.launch(headless=True) - page = browser.new_page() - try: - page.goto(url) - page.wait_for_load_state("networkidle") - - title = page.title() - content = page.content() - - return { - 'url': url, - 'title': title, - 'content': content, - 'extraction_successful': True - } - except Exception as e: - logging.error(f"Error scraping article {url}: {str(e)}") - return { - 'url': url, - 'extraction_successful': False, - 'error': str(e) - } - finally: - browser.close() - - -def 
should_scrape_url(url: str) -> bool: - parsed_url = urlparse(url) - path = parsed_url.path.lower() - - # List of patterns to exclude - exclude_patterns = [ - '/tag/', '/category/', '/author/', '/search/', '/page/', - 'wp-content', 'wp-includes', 'wp-json', 'wp-admin', - 'login', 'register', 'cart', 'checkout', 'account', - '.jpg', '.png', '.gif', '.pdf', '.zip' - ] - - # Check if the URL contains any exclude patterns - if any(pattern in path for pattern in exclude_patterns): - return False - - # Add more sophisticated checks here - # For example, you might want to only include URLs with certain patterns - include_patterns = ['/article/', '/post/', '/blog/'] - if any(pattern in path for pattern in include_patterns): - return True - - # By default, return True if no exclusion or inclusion rules matched - return True - - -async def scrape_with_retry(url: str, max_retries: int = 3, retry_delay: float = 5.0): - for attempt in range(max_retries): - try: - return await scrape_article(url) - except TimeoutError: - if attempt < max_retries - 1: - logging.warning(f"Timeout error scraping {url}. 
Retrying in {retry_delay} seconds...") - await asyncio.sleep(retry_delay) - else: - logging.error(f"Failed to scrape {url} after {max_retries} attempts.") - return None - except Exception as e: - logging.error(f"Error scraping {url}: {str(e)}") - return None - - -def create_website_scraping_tab(): - with gr.TabItem("Website Scraping", visible=True): - gr.Markdown("# Scrape Websites & Summarize Articles") - with gr.Row(): - with gr.Column(): - scrape_method = gr.Radio( - ["Individual URLs", "Sitemap", "URL Level", "Recursive Scraping"], - label="Scraping Method", - value="Individual URLs" - ) - url_input = gr.Textbox( - label="Article URLs or Base URL", - placeholder="Enter article URLs here, one per line, or base URL for sitemap/URL level/recursive scraping", - lines=5 - ) - url_level = gr.Slider( - minimum=1, - maximum=10, - step=1, - label="URL Level (for URL Level scraping)", - value=2, - visible=False - ) - max_pages = gr.Slider( - minimum=1, - maximum=100, - step=1, - label="Maximum Pages to Scrape (for Recursive Scraping)", - value=10, - visible=False - ) - max_depth = gr.Slider( - minimum=1, - maximum=10, - step=1, - label="Maximum Depth (for Recursive Scraping)", - value=3, - visible=False - ) - custom_article_title_input = gr.Textbox( - label="Custom Article Titles (Optional, one per line)", - placeholder="Enter custom titles for the articles, one per line", - lines=5 - ) - with gr.Row(): - summarize_checkbox = gr.Checkbox(label="Summarize Articles", value=False) - custom_prompt_checkbox = gr.Checkbox(label="Use a Custom Prompt", value=False, visible=True) - preset_prompt_checkbox = gr.Checkbox(label="Use a pre-set Prompt", value=False, visible=True) - with gr.Row(): - temp_slider = gr.Slider(0.1, 2.0, 0.7, label="Temperature") - with gr.Row(): - preset_prompt = gr.Dropdown( - label="Select Preset Prompt", - choices=load_preset_prompts(), - visible=False - ) - with gr.Row(): - website_custom_prompt_input = gr.Textbox( - label="Custom Prompt", - 
placeholder="Enter custom prompt here", - lines=3, - visible=False - ) - with gr.Row(): - system_prompt_input = gr.Textbox( - label="System Prompt", - value="""You are a bulleted notes specialist. [INST]```When creating comprehensive bulleted notes, you should follow these guidelines: Use multiple headings based on the referenced topics, not categories like quotes or terms. Headings should be surrounded by bold formatting and not be listed as bullet points themselves. Leave no space between headings and their corresponding list items underneath. Important terms within the content should be emphasized by setting them in bold font. Any text that ends with a colon should also be bolded. Before submitting your response, review the instructions, and make any corrections necessary to adhered to the specified format. Do not reference these instructions within the notes.``` \nBased on the content between backticks create comprehensive bulleted notes.[/INST] - **Bulleted Note Creation Guidelines** - - **Headings**: - - Based on referenced topics, not categories like quotes or terms - - Surrounded by **bold** formatting - - Not listed as bullet points - - No space between headings and list items underneath - - **Emphasis**: - - **Important terms** set in bold font - - **Text ending in a colon**: also bolded - - **Review**: - - Ensure adherence to specified format - - Do not reference these instructions in your response.[INST] {{ .Prompt }} [/INST] - """, - lines=3, - visible=False - ) - - api_name_input = gr.Dropdown( - choices=[None, "Local-LLM", "OpenAI", "Anthropic", "Cohere", "Groq", "DeepSeek", "Mistral", - "OpenRouter", - "Llama.cpp", "Kobold", "Ooba", "Tabbyapi", "VLLM", "ollama", "HuggingFace", - "Custom-OpenAI-API"], - value=None, - label="API Name (Mandatory for Summarization)" - ) - api_key_input = gr.Textbox( - label="API Key (Mandatory if API Name is specified)", - placeholder="Enter your API key here; Ignore if using Local API or Built-in API", - 
type="password" - ) - keywords_input = gr.Textbox( - label="Keywords", - placeholder="Enter keywords here (comma-separated)", - value="default,no_keyword_set", - visible=True - ) - - scrape_button = gr.Button("Scrape and Summarize") - with gr.Column(): - progress_output = gr.Textbox(label="Progress", lines=3) - result_output = gr.Textbox(label="Result", lines=20) - - def update_ui_for_scrape_method(method): - url_level_update = gr.update(visible=(method == "URL Level")) - max_pages_update = gr.update(visible=(method == "Recursive Scraping")) - max_depth_update = gr.update(visible=(method == "Recursive Scraping")) - url_input_update = gr.update( - label="Article URLs" if method == "Individual URLs" else "Base URL", - placeholder="Enter article URLs here, one per line" if method == "Individual URLs" else "Enter the base URL for scraping" - ) - return url_level_update, max_pages_update, max_depth_update, url_input_update - - scrape_method.change( - fn=update_ui_for_scrape_method, - inputs=[scrape_method], - outputs=[url_level, max_pages, max_depth, url_input] - ) - - custom_prompt_checkbox.change( - fn=lambda x: (gr.update(visible=x), gr.update(visible=x)), - inputs=[custom_prompt_checkbox], - outputs=[website_custom_prompt_input, system_prompt_input] - ) - preset_prompt_checkbox.change( - fn=lambda x: gr.update(visible=x), - inputs=[preset_prompt_checkbox], - outputs=[preset_prompt] - ) - - def update_prompts(preset_name): - prompts = update_user_prompt(preset_name) - return ( - gr.update(value=prompts["user_prompt"], visible=True), - gr.update(value=prompts["system_prompt"], visible=True) - ) - - preset_prompt.change( - update_prompts, - inputs=preset_prompt, - outputs=[website_custom_prompt_input, system_prompt_input] - ) - - async def scrape_and_summarize_wrapper( - scrape_method: str, - url_input: str, - url_level: Optional[int], - max_pages: int, - max_depth: int, - summarize_checkbox: bool, - custom_prompt: Optional[str], - api_name: Optional[str], - api_key: 
Optional[str], - keywords: str, - custom_titles: Optional[str], - system_prompt: Optional[str], - temperature: float = 0.7, - progress: gr.Progress = gr.Progress() - ) -> str: - try: - result: List[Dict[str, Any]] = [] - - if scrape_method == "Individual URLs": - result = await scrape_and_summarize_multiple(url_input, custom_prompt, api_name, api_key, keywords, - custom_titles, system_prompt) - elif scrape_method == "Sitemap": - result = await asyncio.to_thread(scrape_from_sitemap, url_input) - elif scrape_method == "URL Level": - if url_level is None: - return convert_json_to_markdown( - json.dumps({"error": "URL level is required for URL Level scraping."})) - result = await asyncio.to_thread(scrape_by_url_level, url_input, url_level) - elif scrape_method == "Recursive Scraping": - result = await recursive_scrape(url_input, max_pages, max_depth, progress.update, delay=1.0) - else: - return convert_json_to_markdown(json.dumps({"error": f"Unknown scraping method: {scrape_method}"})) - - # Ensure result is always a list of dictionaries - if isinstance(result, dict): - result = [result] - elif isinstance(result, list): - if all(isinstance(item, str) for item in result): - # Convert list of strings to list of dictionaries - result = [{"content": item} for item in result] - elif not all(isinstance(item, dict) for item in result): - raise ValueError("Not all items in result are dictionaries or strings") - else: - raise ValueError(f"Unexpected result type: {type(result)}") - - # Ensure all items in result are dictionaries - if not all(isinstance(item, dict) for item in result): - raise ValueError("Not all items in result are dictionaries") - - if summarize_checkbox: - total_articles = len(result) - for i, article in enumerate(result): - progress.update(f"Summarizing article {i + 1}/{total_articles}") - content = article.get('content', '') - if content: - summary = await asyncio.to_thread(summarize, content, custom_prompt, api_name, api_key, - temperature, system_prompt) - 
article['summary'] = summary - else: - article['summary'] = "No content available to summarize." - - # Concatenate all content - all_content = "\n\n".join( - [f"# {article.get('title', 'Untitled')}\n\n{article.get('content', '')}\n\n" + - (f"Summary: {article.get('summary', '')}" if summarize_checkbox else "") - for article in result]) - - # Collect all unique URLs - all_urls = list(set(article.get('url', '') for article in result if article.get('url'))) - - # Structure the output for the entire website collection - website_collection = { - "base_url": url_input, - "scrape_method": scrape_method, - "summarization_performed": summarize_checkbox, - "api_used": api_name if summarize_checkbox else None, - "keywords": keywords if summarize_checkbox else None, - "url_level": url_level if scrape_method == "URL Level" else None, - "max_pages": max_pages if scrape_method == "Recursive Scraping" else None, - "max_depth": max_depth if scrape_method == "Recursive Scraping" else None, - "total_articles_scraped": len(result), - "urls_scraped": all_urls, - "content": all_content - } - - # Convert the JSON to markdown and return - return convert_json_to_markdown(json.dumps(website_collection, indent=2)) - except Exception as e: - return convert_json_to_markdown(json.dumps({"error": f"An error occurred: {str(e)}"})) - - # Update the scrape_button.click to include the temperature parameter - scrape_button.click( - fn=lambda *args: asyncio.run(scrape_and_summarize_wrapper(*args)), - inputs=[scrape_method, url_input, url_level, max_pages, max_depth, summarize_checkbox, - website_custom_prompt_input, api_name_input, api_key_input, keywords_input, - custom_article_title_input, system_prompt_input, temp_slider], - outputs=[result_output] - ) - - -def convert_json_to_markdown(json_str: str) -> str: - """ - Converts the JSON output from the scraping process into a markdown format. 
- - Args: - json_str (str): JSON-formatted string containing the website collection data - - Returns: - str: Markdown-formatted string of the website collection data - """ - try: - # Parse the JSON string - data = json.loads(json_str) - - # Check if there's an error in the JSON - if "error" in data: - return f"# Error\n\n{data['error']}" - - # Start building the markdown string - markdown = f"# Website Collection: {data['base_url']}\n\n" - - # Add metadata - markdown += "## Metadata\n\n" - markdown += f"- **Scrape Method:** {data['scrape_method']}\n" - markdown += f"- **API Used:** {data['api_used']}\n" - markdown += f"- **Keywords:** {data['keywords']}\n" - if data['url_level'] is not None: - markdown += f"- **URL Level:** {data['url_level']}\n" - markdown += f"- **Total Articles Scraped:** {data['total_articles_scraped']}\n\n" - - # Add URLs scraped - markdown += "## URLs Scraped\n\n" - for url in data['urls_scraped']: - markdown += f"- {url}\n" - markdown += "\n" - - # Add the content - markdown += "## Content\n\n" - markdown += data['content'] - - return markdown - - except json.JSONDecodeError: - return "# Error\n\nInvalid JSON string provided." 
- except KeyError as e: - return f"# Error\n\nMissing key in JSON data: {str(e)}" - except Exception as e: - return f"# Error\n\nAn unexpected error occurred: {str(e)}" -# -# End of File -######################################################################################################################## +# Website_scraping_tab.py +# Gradio UI for scraping websites +# +# Imports +import asyncio +import json +import logging +import os +import random +from concurrent.futures import ThreadPoolExecutor +from typing import Optional, List, Dict, Any +from urllib.parse import urlparse, urljoin + +# +# External Imports +import gradio as gr +from playwright.async_api import TimeoutError, async_playwright +from playwright.sync_api import sync_playwright + +from App_Function_Libraries.Utils.Utils import default_api_endpoint, global_api_endpoints, format_api_name +# +# Local Imports +from App_Function_Libraries.Web_Scraping.Article_Extractor_Lib import scrape_from_sitemap, scrape_by_url_level, \ + scrape_article, collect_bookmarks, scrape_and_summarize_multiple, collect_urls_from_file +from App_Function_Libraries.DB.DB_Manager import list_prompts +from App_Function_Libraries.Gradio_UI.Chat_ui import update_user_prompt +from App_Function_Libraries.Summarization.Summarization_General_Lib import summarize + + +# +######################################################################################################################## +# +# Functions: + +def get_url_depth(url: str) -> int: + return len(urlparse(url).path.strip('/').split('/')) + + +def sync_recursive_scrape(url_input, max_pages, max_depth, progress_callback, delay=1.0, custom_cookies=None): + def run_async_scrape(): + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + return loop.run_until_complete( + recursive_scrape(url_input, max_pages, max_depth, progress_callback, delay, custom_cookies=custom_cookies) + ) + + with ThreadPoolExecutor() as executor: + future = executor.submit(run_async_scrape) + 
return future.result() + + +async def recursive_scrape( + base_url: str, + max_pages: int, + max_depth: int, + progress_callback: callable, + delay: float = 1.0, + resume_file: str = 'scrape_progress.json', + user_agent: str = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/58.0.3029.110 Safari/537.3", + custom_cookies: Optional[List[Dict[str, Any]]] = None +) -> List[Dict]: + async def save_progress(): + temp_file = resume_file + ".tmp" + with open(temp_file, 'w') as f: + json.dump({ + 'visited': list(visited), + 'to_visit': to_visit, + 'scraped_articles': scraped_articles, + 'pages_scraped': pages_scraped + }, f) + os.replace(temp_file, resume_file) # Atomic replace + + def is_valid_url(url: str) -> bool: + return url.startswith("http") and len(url) > 0 + + # Load progress if resume file exists + if os.path.exists(resume_file): + with open(resume_file, 'r') as f: + progress_data = json.load(f) + visited = set(progress_data['visited']) + to_visit = progress_data['to_visit'] + scraped_articles = progress_data['scraped_articles'] + pages_scraped = progress_data['pages_scraped'] + else: + visited = set() + to_visit = [(base_url, 0)] # (url, depth) + scraped_articles = [] + pages_scraped = 0 + + try: + async with async_playwright() as p: + browser = await p.chromium.launch(headless=True) + context = await browser.new_context(user_agent=user_agent) + + # Set custom cookies if provided + if custom_cookies: + await context.add_cookies(custom_cookies) + + try: + while to_visit and pages_scraped < max_pages: + current_url, current_depth = to_visit.pop(0) + + if current_url in visited or current_depth > max_depth: + continue + + visited.add(current_url) + + # Update progress + progress_callback(f"Scraping page {pages_scraped + 1}/{max_pages}: {current_url}") + + try: + await asyncio.sleep(random.uniform(delay * 0.8, delay * 1.2)) + + # This function should be implemented to handle asynchronous scraping + article_data = await 
scrape_article_async(context, current_url) + + if article_data and article_data['extraction_successful']: + scraped_articles.append(article_data) + pages_scraped += 1 + + # If we haven't reached max depth, add child links to to_visit + if current_depth < max_depth: + page = await context.new_page() + await page.goto(current_url) + await page.wait_for_load_state("networkidle") + + links = await page.eval_on_selector_all('a[href]', + "(elements) => elements.map(el => el.href)") + for link in links: + child_url = urljoin(base_url, link) + if is_valid_url(child_url) and child_url.startswith( + base_url) and child_url not in visited and should_scrape_url(child_url): + to_visit.append((child_url, current_depth + 1)) + + await page.close() + + except Exception as e: + logging.error(f"Error scraping {current_url}: {str(e)}") + + # Save progress periodically (e.g., every 10 pages) + if pages_scraped % 10 == 0: + await save_progress() + + finally: + await browser.close() + + finally: + # These statements are now guaranteed to be reached after the scraping is done + await save_progress() + + # Remove the progress file when scraping is completed successfully + if os.path.exists(resume_file): + os.remove(resume_file) + + # Final progress update + progress_callback(f"Scraping completed. 
Total pages scraped: {pages_scraped}") + + return scraped_articles + + +async def scrape_article_async(context, url: str) -> Dict[str, Any]: + page = await context.new_page() + try: + await page.goto(url) + await page.wait_for_load_state("networkidle") + + title = await page.title() + content = await page.content() + + return { + 'url': url, + 'title': title, + 'content': content, + 'extraction_successful': True + } + except Exception as e: + logging.error(f"Error scraping article {url}: {str(e)}") + return { + 'url': url, + 'extraction_successful': False, + 'error': str(e) + } + finally: + await page.close() + + +def scrape_article_sync(url: str) -> Dict[str, Any]: + with sync_playwright() as p: + browser = p.chromium.launch(headless=True) + page = browser.new_page() + try: + page.goto(url) + page.wait_for_load_state("networkidle") + + title = page.title() + content = page.content() + + return { + 'url': url, + 'title': title, + 'content': content, + 'extraction_successful': True + } + except Exception as e: + logging.error(f"Error scraping article {url}: {str(e)}") + return { + 'url': url, + 'extraction_successful': False, + 'error': str(e) + } + finally: + browser.close() + + +def should_scrape_url(url: str) -> bool: + parsed_url = urlparse(url) + path = parsed_url.path.lower() + + # List of patterns to exclude + exclude_patterns = [ + '/tag/', '/category/', '/author/', '/search/', '/page/', + 'wp-content', 'wp-includes', 'wp-json', 'wp-admin', + 'login', 'register', 'cart', 'checkout', 'account', + '.jpg', '.png', '.gif', '.pdf', '.zip' + ] + + # Check if the URL contains any exclude patterns + if any(pattern in path for pattern in exclude_patterns): + return False + + # Add more sophisticated checks here + # For example, you might want to only include URLs with certain patterns + include_patterns = ['/article/', '/post/', '/blog/'] + if any(pattern in path for pattern in include_patterns): + return True + + # By default, return True if no exclusion or 
inclusion rules matched + return True + + +async def scrape_with_retry(url: str, max_retries: int = 3, retry_delay: float = 5.0): + for attempt in range(max_retries): + try: + return await scrape_article(url) + except TimeoutError: + if attempt < max_retries - 1: + logging.warning(f"Timeout error scraping {url}. Retrying in {retry_delay} seconds...") + await asyncio.sleep(retry_delay) + else: + logging.error(f"Failed to scrape {url} after {max_retries} attempts.") + return None + except Exception as e: + logging.error(f"Error scraping {url}: {str(e)}") + return None + + +def create_website_scraping_tab(): + try: + default_value = None + if default_api_endpoint: + if default_api_endpoint in global_api_endpoints: + default_value = format_api_name(default_api_endpoint) + else: + logging.warning(f"Default API endpoint '{default_api_endpoint}' not found in global_api_endpoints") + except Exception as e: + logging.error(f"Error setting default API endpoint: {str(e)}") + default_value = None + with gr.TabItem("Website Scraping", visible=True): + gr.Markdown("# Scrape Websites & Summarize Articles") + with gr.Row(): + with gr.Column(): + scrape_method = gr.Radio( + ["Individual URLs", "Sitemap", "URL Level", "Recursive Scraping"], + label="Scraping Method", + value="Individual URLs" + ) + url_input = gr.Textbox( + label="Article URLs or Base URL", + placeholder="Enter article URLs here, one per line, or base URL for sitemap/URL level/recursive scraping", + lines=5 + ) + url_level = gr.Slider( + minimum=1, + maximum=10, + step=1, + label="URL Level (for URL Level scraping)", + value=2, + visible=False + ) + max_pages = gr.Slider( + minimum=1, + maximum=100, + step=1, + label="Maximum Pages to Scrape (for Recursive Scraping)", + value=10, + visible=False + ) + max_depth = gr.Slider( + minimum=1, + maximum=10, + step=1, + label="Maximum Depth (for Recursive Scraping)", + value=3, + visible=False + ) + custom_article_title_input = gr.Textbox( + label="Custom Article Titles 
(Optional, one per line)", + placeholder="Enter custom titles for the articles, one per line", + lines=5 + ) + with gr.Row(): + summarize_checkbox = gr.Checkbox(label="Summarize/Analyze Articles", value=False) + custom_prompt_checkbox = gr.Checkbox(label="Use a Custom Prompt", value=False, visible=True) + preset_prompt_checkbox = gr.Checkbox(label="Use a pre-set Prompt", value=False, visible=True) + with gr.Row(): + temp_slider = gr.Slider(0.1, 2.0, 0.7, label="Temperature") + + # Initialize state variables for pagination + current_page_state = gr.State(value=1) + total_pages_state = gr.State(value=1) + with gr.Row(): + # Add pagination controls + preset_prompt = gr.Dropdown( + label="Select Preset Prompt", + choices=[], + visible=False + ) + with gr.Row(): + prev_page_button = gr.Button("Previous Page", visible=False) + page_display = gr.Markdown("Page 1 of X", visible=False) + next_page_button = gr.Button("Next Page", visible=False) + + with gr.Row(): + website_custom_prompt_input = gr.Textbox( + label="Custom Prompt", + placeholder="Enter custom prompt here", + lines=3, + visible=False + ) + with gr.Row(): + system_prompt_input = gr.Textbox( + label="System Prompt", + value="""You are a bulleted notes specialist. [INST]```When creating comprehensive bulleted notes, you should follow these guidelines: Use multiple headings based on the referenced topics, not categories like quotes or terms. Headings should be surrounded by bold formatting and not be listed as bullet points themselves. Leave no space between headings and their corresponding list items underneath. Important terms within the content should be emphasized by setting them in bold font. Any text that ends with a colon should also be bolded. Before submitting your response, review the instructions, and make any corrections necessary to adhered to the specified format. 
Do not reference these instructions within the notes.``` \nBased on the content between backticks create comprehensive bulleted notes.[/INST] + **Bulleted Note Creation Guidelines** + + **Headings**: + - Based on referenced topics, not categories like quotes or terms + - Surrounded by **bold** formatting + - Not listed as bullet points + - No space between headings and list items underneath + + **Emphasis**: + - **Important terms** set in bold font + - **Text ending in a colon**: also bolded + + **Review**: + - Ensure adherence to specified format + - Do not reference these instructions in your response.[INST] {{ .Prompt }} [/INST] + """, + lines=3, + visible=False + ) + + # Refactored API selection dropdown + api_name_input = gr.Dropdown( + choices=["None"] + [format_api_name(api) for api in global_api_endpoints], + value=default_value, + label="API for Summarization/Analysis (Optional)" + ) + api_key_input = gr.Textbox( + label="API Key (Mandatory if API Name is specified)", + placeholder="Enter your API key here; Ignore if using Local API or Built-in API", + type="password" + ) + custom_cookies_input = gr.Textbox( + label="Custom Cookies (JSON format)", + placeholder="Enter custom cookies in JSON format", + lines=3, + visible=True + ) + keywords_input = gr.Textbox( + label="Keywords", + placeholder="Enter keywords here (comma-separated)", + value="default,no_keyword_set", + visible=True + ) + bookmarks_file_input = gr.File( + label="Upload Bookmarks File/CSV", + type="filepath", + file_types=[".json", ".html", ".csv"], # Added .csv + visible=True + ) + gr.Markdown(""" + Supported file formats: + - Chrome/Edge bookmarks (JSON) + - Firefox bookmarks (HTML) + - CSV file with 'url' column (optionally 'title' or 'name' column) + """) + parsed_urls_output = gr.Textbox( + label="Parsed URLs", + placeholder="URLs will be displayed here after uploading a file.", + lines=10, + interactive=False, + visible=False + ) + + scrape_button = gr.Button("Scrape and Summarize") + + 
with gr.Column(): + progress_output = gr.Textbox(label="Progress", lines=3) + result_output = gr.Textbox(label="Result", lines=20) + + def update_ui_for_scrape_method(method): + url_level_update = gr.update(visible=(method == "URL Level")) + max_pages_update = gr.update(visible=(method == "Recursive Scraping")) + max_depth_update = gr.update(visible=(method == "Recursive Scraping")) + url_input_update = gr.update( + label="Article URLs" if method == "Individual URLs" else "Base URL", + placeholder="Enter article URLs here, one per line" if method == "Individual URLs" else "Enter the base URL for scraping" + ) + return url_level_update, max_pages_update, max_depth_update, url_input_update + + scrape_method.change( + fn=update_ui_for_scrape_method, + inputs=[scrape_method], + outputs=[url_level, max_pages, max_depth, url_input] + ) + + custom_prompt_checkbox.change( + fn=lambda x: (gr.update(visible=x), gr.update(visible=x)), + inputs=[custom_prompt_checkbox], + outputs=[website_custom_prompt_input, system_prompt_input] + ) + + def on_preset_prompt_checkbox_change(is_checked): + if is_checked: + prompts, total_pages, current_page = list_prompts(page=1, per_page=20) + page_display_text = f"Page {current_page} of {total_pages}" + return ( + gr.update(visible=True, interactive=True, choices=prompts), # preset_prompt + gr.update(visible=True), # prev_page_button + gr.update(visible=True), # next_page_button + gr.update(value=page_display_text, visible=True), # page_display + current_page, # current_page_state + total_pages # total_pages_state + ) + else: + return ( + gr.update(visible=False, interactive=False), # preset_prompt + gr.update(visible=False), # prev_page_button + gr.update(visible=False), # next_page_button + gr.update(visible=False), # page_display + 1, # current_page_state + 1 # total_pages_state + ) + + preset_prompt_checkbox.change( + fn=on_preset_prompt_checkbox_change, + inputs=[preset_prompt_checkbox], + outputs=[preset_prompt, prev_page_button, 
next_page_button, page_display, current_page_state, total_pages_state] + ) + + def on_prev_page_click(current_page, total_pages): + new_page = max(current_page - 1, 1) + prompts, total_pages, current_page = list_prompts(page=new_page, per_page=20) + page_display_text = f"Page {current_page} of {total_pages}" + return gr.update(choices=prompts), gr.update(value=page_display_text), current_page + + prev_page_button.click( + fn=on_prev_page_click, + inputs=[current_page_state, total_pages_state], + outputs=[preset_prompt, page_display, current_page_state] + ) + + def on_next_page_click(current_page, total_pages): + new_page = min(current_page + 1, total_pages) + prompts, total_pages, current_page = list_prompts(page=new_page, per_page=20) + page_display_text = f"Page {current_page} of {total_pages}" + return gr.update(choices=prompts), gr.update(value=page_display_text), current_page + + next_page_button.click( + fn=on_next_page_click, + inputs=[current_page_state, total_pages_state], + outputs=[preset_prompt, page_display, current_page_state] + ) + + def update_prompts(preset_name): + prompts = update_user_prompt(preset_name) + return ( + gr.update(value=prompts["user_prompt"], visible=True), + gr.update(value=prompts["system_prompt"], visible=True) + ) + + preset_prompt.change( + update_prompts, + inputs=[preset_prompt], + outputs=[website_custom_prompt_input, system_prompt_input] + ) + + def parse_bookmarks(file_path): + """ + Parses the uploaded bookmarks file and extracts URLs. + + Args: + file_path (str): Path to the uploaded bookmarks file. + + Returns: + str: Formatted string of extracted URLs or error message. + """ + try: + bookmarks = collect_bookmarks(file_path) + # Extract URLs + urls = [] + for value in bookmarks.values(): + if isinstance(value, list): + urls.extend(value) + elif isinstance(value, str): + urls.append(value) + if not urls: + return "No URLs found in the bookmarks file." 
+ # Format URLs for display + formatted_urls = "\n".join(urls) + return formatted_urls + except Exception as e: + logging.error(f"Error parsing bookmarks file: {str(e)}") + return f"Error parsing bookmarks file: {str(e)}" + + def show_parsed_urls(urls_file): + """ + Determines whether to show the parsed URLs output. + + Args: + urls_file: Uploaded file object. + + Returns: + Tuple indicating visibility and content of parsed_urls_output. + """ + if urls_file is None: + return gr.update(visible=False), "" + + file_path = urls_file.name + try: + # Use the unified collect_urls_from_file function + parsed_urls = collect_urls_from_file(file_path) + + # Format the URLs for display + formatted_urls = [] + for name, urls in parsed_urls.items(): + if isinstance(urls, list): + for url in urls: + formatted_urls.append(f"{name}: {url}") + else: + formatted_urls.append(f"{name}: {urls}") + + return gr.update(visible=True), "\n".join(formatted_urls) + except Exception as e: + return gr.update(visible=True), f"Error parsing file: {str(e)}" + + # Connect the parsing function to the file upload event + bookmarks_file_input.change( + fn=show_parsed_urls, + inputs=[bookmarks_file_input], + outputs=[parsed_urls_output, parsed_urls_output] + ) + + async def scrape_and_summarize_wrapper( + scrape_method: str, + url_input: str, + url_level: Optional[int], + max_pages: int, + max_depth: int, + summarize_checkbox: bool, + custom_prompt: Optional[str], + api_name: Optional[str], + api_key: Optional[str], + keywords: str, + custom_titles: Optional[str], + system_prompt: Optional[str], + temperature: float, + custom_cookies: Optional[str], + bookmarks_file, + progress: gr.Progress = gr.Progress() + ) -> str: + try: + result: List[Dict[str, Any]] = [] + + # Handle bookmarks file if provided + if bookmarks_file is not None: + bookmarks = collect_bookmarks(bookmarks_file.name) + # Extract URLs from bookmarks + urls_from_bookmarks = [] + for value in bookmarks.values(): + if isinstance(value, 
list): + urls_from_bookmarks.extend(value) + elif isinstance(value, str): + urls_from_bookmarks.append(value) + if scrape_method == "Individual URLs": + url_input = "\n".join(urls_from_bookmarks) + else: + if urls_from_bookmarks: + url_input = urls_from_bookmarks[0] + else: + return convert_json_to_markdown(json.dumps({"error": "No URLs found in the bookmarks file."})) + + # Handle custom cookies + custom_cookies_list = None + if custom_cookies: + try: + custom_cookies_list = json.loads(custom_cookies) + if not isinstance(custom_cookies_list, list): + custom_cookies_list = [custom_cookies_list] + except json.JSONDecodeError as e: + return convert_json_to_markdown(json.dumps({"error": f"Invalid JSON format for custom cookies: {e}"})) + + if scrape_method == "Individual URLs": + result = await scrape_and_summarize_multiple(url_input, custom_prompt, api_name, api_key, keywords, + custom_titles, system_prompt, summarize_checkbox, custom_cookies=custom_cookies_list) + elif scrape_method == "Sitemap": + result = await asyncio.to_thread(scrape_from_sitemap, url_input) + elif scrape_method == "URL Level": + if url_level is None: + return convert_json_to_markdown( + json.dumps({"error": "URL level is required for URL Level scraping."})) + result = await asyncio.to_thread(scrape_by_url_level, url_input, url_level) + elif scrape_method == "Recursive Scraping": + result = await recursive_scrape(url_input, max_pages, max_depth, progress.update, delay=1.0, + custom_cookies=custom_cookies_list) + else: + return convert_json_to_markdown(json.dumps({"error": f"Unknown scraping method: {scrape_method}"})) + + # Ensure result is always a list of dictionaries + if isinstance(result, dict): + result = [result] + elif isinstance(result, list): + if all(isinstance(item, str) for item in result): + # Convert list of strings to list of dictionaries + result = [{"content": item} for item in result] + elif not all(isinstance(item, dict) for item in result): + raise ValueError("Not all items 
in result are dictionaries or strings") + else: + raise ValueError(f"Unexpected result type: {type(result)}") + + # Ensure all items in result are dictionaries + if not all(isinstance(item, dict) for item in result): + raise ValueError("Not all items in result are dictionaries") + + if summarize_checkbox: + total_articles = len(result) + for i, article in enumerate(result): + progress.update(f"Summarizing article {i + 1}/{total_articles}") + content = article.get('content', '') + if content: + summary = await asyncio.to_thread(summarize, content, custom_prompt, api_name, api_key, + temperature, system_prompt) + article['summary'] = summary + else: + article['summary'] = "No content available to summarize." + + # Concatenate all content + all_content = "\n\n".join( + [f"# {article.get('title', 'Untitled')}\n\n{article.get('content', '')}\n\n" + + (f"Summary: {article.get('summary', '')}" if summarize_checkbox else "") + for article in result]) + + # Collect all unique URLs + all_urls = list(set(article.get('url', '') for article in result if article.get('url'))) + + # Structure the output for the entire website collection + website_collection = { + "base_url": url_input, + "scrape_method": scrape_method, + "summarization_performed": summarize_checkbox, + "api_used": api_name if summarize_checkbox else None, + "keywords": keywords if summarize_checkbox else None, + "url_level": url_level if scrape_method == "URL Level" else None, + "max_pages": max_pages if scrape_method == "Recursive Scraping" else None, + "max_depth": max_depth if scrape_method == "Recursive Scraping" else None, + "total_articles_scraped": len(result), + "urls_scraped": all_urls, + "content": all_content + } + + # Convert the JSON to markdown and return + return convert_json_to_markdown(json.dumps(website_collection, indent=2)) + except Exception as e: + return convert_json_to_markdown(json.dumps({"error": f"An error occurred: {str(e)}"})) + + # Update the scrape_button.click to include the 
temperature parameter + scrape_button.click( + fn=lambda *args: asyncio.run(scrape_and_summarize_wrapper(*args)), + inputs=[scrape_method, url_input, url_level, max_pages, max_depth, summarize_checkbox, + website_custom_prompt_input, api_name_input, api_key_input, keywords_input, + custom_article_title_input, system_prompt_input, temp_slider, + custom_cookies_input, bookmarks_file_input], + outputs=[result_output] + ) + + +def convert_json_to_markdown(json_str: str) -> str: + """ + Converts the JSON output from the scraping process into a markdown format. + + Args: + json_str (str): JSON-formatted string containing the website collection data + + Returns: + str: Markdown-formatted string of the website collection data + """ + try: + # Parse the JSON string + data = json.loads(json_str) + + # Check if there's an error in the JSON + if "error" in data: + return f"# Error\n\n{data['error']}" + + # Start building the markdown string + markdown = f"# Website Collection: {data['base_url']}\n\n" + + # Add metadata + markdown += "## Metadata\n\n" + markdown += f"- **Scrape Method:** {data['scrape_method']}\n" + markdown += f"- **API Used:** {data['api_used']}\n" + markdown += f"- **Keywords:** {data['keywords']}\n" + if data.get('url_level') is not None: + markdown += f"- **URL Level:** {data['url_level']}\n" + if data.get('max_pages') is not None: + markdown += f"- **Maximum Pages:** {data['max_pages']}\n" + if data.get('max_depth') is not None: + markdown += f"- **Maximum Depth:** {data['max_depth']}\n" + markdown += f"- **Total Articles Scraped:** {data['total_articles_scraped']}\n\n" + + # Add URLs Scraped + markdown += "## URLs Scraped\n\n" + for url in data['urls_scraped']: + markdown += f"- {url}\n" + markdown += "\n" + + # Add the content + markdown += "## Content\n\n" + markdown += data['content'] + + return markdown + + except json.JSONDecodeError: + return "# Error\n\nInvalid JSON string provided." 
+ except KeyError as e: + return f"# Error\n\nMissing key in JSON data: {str(e)}" + except Exception as e: + return f"# Error\n\nAn unexpected error occurred: {str(e)}" + +# +# End of File +######################################################################################################################## diff --git a/App_Function_Libraries/Gradio_UI/Workflows_tab.py b/App_Function_Libraries/Gradio_UI/Workflows_tab.py new file mode 100644 index 0000000000000000000000000000000000000000..999f4869f0be26d213f6cb8c16e56ab6c791cd19 --- /dev/null +++ b/App_Function_Libraries/Gradio_UI/Workflows_tab.py @@ -0,0 +1,190 @@ +# Chat_Workflows.py +# Description: Gradio UI for Chat Workflows +# +# Imports +import json +import logging +from pathlib import Path +# +# External Imports +import gradio as gr +# +# Local Imports +from App_Function_Libraries.Gradio_UI.Chat_ui import chat_wrapper, search_conversations, \ + load_conversation +from App_Function_Libraries.Chat.Chat_Functions import save_chat_history_to_db_wrapper +from App_Function_Libraries.Utils.Utils import default_api_endpoint, global_api_endpoints, format_api_name +# +############################################################################################################ +# +# Functions: + +# Load workflows from a JSON file +json_path = Path('./Helper_Scripts/Workflows/Workflows.json') +with json_path.open('r') as f: + workflows = json.load(f) + + +def chat_workflows_tab(): + try: + default_value = None + if default_api_endpoint: + if default_api_endpoint in global_api_endpoints: + default_value = format_api_name(default_api_endpoint) + else: + logging.warning(f"Default API endpoint '{default_api_endpoint}' not found in global_api_endpoints") + except Exception as e: + logging.error(f"Error setting default API endpoint: {str(e)}") + default_value = None + with gr.TabItem("Chat Workflows", visible=True): + gr.Markdown("# Workflows using LLMs") + chat_history = gr.State([]) + media_content = gr.State({}) + 
def chat_workflows_tab():
    """
    Build the "Chat Workflows" tab of the Gradio UI.

    The tab steps the user through the prompts of a workflow picked from the
    module-level ``workflows`` list. Each answer is sent to the selected API
    through ``chat_wrapper`` and the next prompt is appended to the chat.

    Returns:
        tuple: (workflow_selector, api_selector, api_key_input, context_input,
        chatbot, msg, submit_btn, clear_btn, save_btn)
    """
    def blank_workflow_state(max_steps=0):
        # Build a fresh dict on every call so events never share one object.
        return {"current_step": 0, "max_steps": max_steps, "conversation_id": None}

    try:
        default_value = None
        if default_api_endpoint:
            if default_api_endpoint in global_api_endpoints:
                default_value = format_api_name(default_api_endpoint)
            else:
                logging.warning(f"Default API endpoint '{default_api_endpoint}' not found in global_api_endpoints")
    except Exception as e:
        logging.error(f"Error setting default API endpoint: {str(e)}")
        default_value = None

    with gr.TabItem("Chat Workflows", visible=True):
        gr.Markdown("# Workflows using LLMs")
        chat_history = gr.State([])
        media_content = gr.State({})
        selected_parts = gr.State([])
        conversation_id = gr.State(None)
        workflow_state = gr.State(blank_workflow_state())

        with gr.Row():
            with gr.Column():
                workflow_selector = gr.Dropdown(label="Select Workflow", choices=[wf['name'] for wf in workflows])
                # Refactored API selection dropdown
                api_selector = gr.Dropdown(
                    choices=["None"] + [format_api_name(api) for api in global_api_endpoints],
                    value=default_value,
                    label="API for Interaction (Optional)"
                )
                api_key_input = gr.Textbox(label="API Key (optional)", type="password")
                temperature = gr.Slider(label="Temperature", minimum=0.00, maximum=1.0, step=0.05, value=0.7)
                save_conversation = gr.Checkbox(label="Save Conversation", value=False)
            with gr.Column():
                gr.Markdown("Placeholder")
        with gr.Row():
            with gr.Column():
                conversation_search = gr.Textbox(label="Search Conversations")
                search_conversations_btn = gr.Button("Search Conversations")
            with gr.Column():
                previous_conversations = gr.Dropdown(label="Select Conversation", choices=[], interactive=True)
                load_conversations_btn = gr.Button("Load Selected Conversation")
        with gr.Row():
            with gr.Column():
                context_input = gr.Textbox(label="Initial Context", lines=5)
                chatbot = gr.Chatbot(label="Workflow Chat")
                msg = gr.Textbox(label="Your Input")
                submit_btn = gr.Button("Submit")
                clear_btn = gr.Button("Clear Chat")
                chat_media_name = gr.Textbox(label="Custom Chat Name(optional)")
                save_btn = gr.Button("Save Chat to Database")
                save_status = gr.Textbox(label="Save Status", interactive=False)

        def update_workflow_ui(workflow_name):
            """Reset the workflow state and show the first prompt of the chosen workflow."""
            if not workflow_name:
                return blank_workflow_state(), "", []
            selected_workflow = next((wf for wf in workflows if wf['name'] == workflow_name), None)
            if not selected_workflow:
                logging.error(f"Selected workflow not found: {workflow_name}")
                return blank_workflow_state(), "", []

            context = selected_workflow.get('context', '')
            prompts = selected_workflow.get('prompts') or []
            if not prompts:
                # A workflow without prompts used to raise IndexError/KeyError
                # here; treat it as a zero-step workflow instead.
                logging.error(f"Workflow '{workflow_name}' defines no prompts")
                return blank_workflow_state(), context, []

            initial_chat = [(None, f"{prompts[0]}")]
            logging.info(f"Initializing workflow: {workflow_name} with {len(prompts)} steps")
            return blank_workflow_state(len(prompts)), context, initial_chat

        def process_workflow_step(message, history, context, workflow_name, api_endpoint, api_key, workflow_state,
                                  save_conv, temp, media, parts):
            """
            Send the user's answer for the current step and advance the workflow.

            ``media`` and ``parts`` are the current values of the
            ``media_content`` / ``selected_parts`` states, passed in as event
            inputs rather than read from the component's ``.value``.

            Returns (chat history, workflow state, update for the input box);
            the input box is disabled once the last step has been processed.
            """
            logging.info(f"Process workflow step called with message: {message}")
            logging.info(f"Current workflow state: {workflow_state}")
            try:
                selected_workflow = next((wf for wf in workflows if wf['name'] == workflow_name), None)
                if not selected_workflow:
                    logging.error(f"Selected workflow not found: {workflow_name}")
                    return history, workflow_state, gr.update(interactive=True)

                current_step = workflow_state["current_step"]
                max_steps = workflow_state["max_steps"]

                logging.info(f"Current step: {current_step}, Max steps: {max_steps}")

                if current_step >= max_steps:
                    logging.info("Workflow completed, disabling input")
                    return history, workflow_state, gr.update(interactive=False)

                prompt = selected_workflow['prompts'][current_step]
                full_message = f"{context}\n\nStep {current_step + 1}: {prompt}\nUser: {message}"

                logging.info(f"Calling chat_wrapper with full_message: {full_message[:100]}...")
                bot_message, new_history, new_conversation_id = chat_wrapper(
                    full_message, history, media, parts,
                    api_endpoint, api_key, "", workflow_state["conversation_id"],
                    save_conv, temp, "You are a helpful assistant guiding through a workflow."
                )

                # str() keeps this log line from failing (and discarding the
                # completed step) when the reply is not a plain string.
                logging.info(f"Received bot_message: {str(bot_message)[:100]}...")

                next_step = current_step + 1
                new_workflow_state = {
                    "current_step": next_step,
                    "max_steps": max_steps,
                    "conversation_id": new_conversation_id
                }

                if next_step >= max_steps:
                    logging.info("Workflow completed after this step")
                    return new_history, new_workflow_state, gr.update(interactive=False)

                next_prompt = selected_workflow['prompts'][next_step]
                new_history.append((None, f"Step {next_step + 1}: {next_prompt}"))
                logging.info(f"Moving to next step: {next_step}")
                return new_history, new_workflow_state, gr.update(interactive=True)
            except Exception as e:
                # Keep the UI usable: log with traceback and leave chat/state as they were.
                logging.error(f"Error in process_workflow_step: {str(e)}", exc_info=True)
                return history, workflow_state, gr.update(interactive=True)

        workflow_selector.change(
            update_workflow_ui,
            inputs=[workflow_selector],
            outputs=[workflow_state, context_input, chatbot]
        )

        submit_btn.click(
            process_workflow_step,
            inputs=[msg, chatbot, context_input, workflow_selector, api_selector, api_key_input, workflow_state,
                    save_conversation, temperature, media_content, selected_parts],
            outputs=[chatbot, workflow_state, msg]
        ).then(
            lambda: gr.update(value=""),
            outputs=[msg]
        )

        clear_btn.click(
            lambda: ([], blank_workflow_state(), ""),
            outputs=[chatbot, workflow_state, context_input]
        )

        save_btn.click(
            save_chat_history_to_db_wrapper,
            inputs=[chatbot, conversation_id, media_content, chat_media_name],
            outputs=[conversation_id, save_status]
        )

        search_conversations_btn.click(
            search_conversations,
            inputs=[conversation_search],
            outputs=[previous_conversations]
        )

        # Reset the workflow first, then load the stored conversation into the chat.
        load_conversations_btn.click(
            lambda: ([], blank_workflow_state(), ""),
            outputs=[chatbot, workflow_state, context_input]
        ).then(
            load_conversation,
            inputs=[previous_conversations],
            outputs=[chatbot, conversation_id]
        )

    return workflow_selector, api_selector, api_key_input, context_input, chatbot, msg, submit_btn, clear_btn, save_btn
"VLLM","ollama", "HuggingFace", "Custom-OpenAI-API"], - value=None, - label="API for Grammar Check" + choices=["None"] + [format_api_name(api) for api in global_api_endpoints], + value=default_value, + label="API for Analysis (Optional)" ) api_key_input = gr.Textbox(label="API Key (if not set in Config_Files/config.txt)", placeholder="Enter your API key here", type="password") @@ -302,63 +317,63 @@ def create_document_feedback_tab(): with gr.Row(): compare_button = gr.Button("Compare Feedback") - feedback_history = gr.State([]) - - def add_custom_persona(name, description): - updated_choices = persona_dropdown.choices + [name] - persona_prompts[name] = f"As {name}, {description}, provide feedback on the following text:" - return gr.update(choices=updated_choices) - - def update_feedback_history(current_text, persona, feedback): - # Ensure feedback_history.value is initialized and is a list - if feedback_history.value is None: - feedback_history.value = [] - - history = feedback_history.value - - # Append the new entry to the history - history.append({"text": current_text, "persona": persona, "feedback": feedback}) - - # Keep only the last 5 entries in the history - feedback_history.value = history[-10:] - - # Generate and return the updated HTML - return generate_feedback_history_html(feedback_history.value) - - def compare_feedback(text, selected_personas, api_name, api_key): - results = [] - for persona in selected_personas: - feedback = generate_writing_feedback(text, persona, "Overall", api_name, api_key) - results.append(f"### {persona}'s Feedback:\n{feedback}\n\n") - return "\n".join(results) - - add_custom_persona_button.click( - fn=add_custom_persona, - inputs=[custom_persona_name, custom_persona_description], - outputs=persona_dropdown - ) - - get_feedback_button.click( - fn=lambda text, persona, aspect, api_name, api_key: ( - generate_writing_feedback(text, persona, aspect, api_name, api_key), - calculate_readability(text), - update_feedback_history(text, 
persona, generate_writing_feedback(text, persona, aspect, api_name, api_key)) - ), - inputs=[input_text, persona_dropdown, aspect_dropdown, api_name_input, api_key_input], - outputs=[feedback_output, readability_output, feedback_history_display] - ) - - compare_button.click( - fn=compare_feedback, - inputs=[input_text, compare_personas, api_name_input, api_key_input], - outputs=feedback_output - ) - - generate_prompt_button.click( - fn=generate_writing_prompt, - inputs=[persona_dropdown, api_name_input, api_key_input], - outputs=input_text - ) + feedback_history = gr.State([]) + + def add_custom_persona(name, description): + updated_choices = persona_dropdown.choices + [name] + persona_prompts[name] = f"As {name}, {description}, provide feedback on the following text:" + return gr.update(choices=updated_choices) + + def update_feedback_history(current_text, persona, feedback): + # Ensure feedback_history.value is initialized and is a list + if feedback_history.value is None: + feedback_history.value = [] + + history = feedback_history.value + + # Append the new entry to the history + history.append({"text": current_text, "persona": persona, "feedback": feedback}) + + # Keep only the last 5 entries in the history + feedback_history.value = history[-10:] + + # Generate and return the updated HTML + return generate_feedback_history_html(feedback_history.value) + + def compare_feedback(text, selected_personas, api_name, api_key): + results = [] + for persona in selected_personas: + feedback = generate_writing_feedback(text, persona, "Overall", api_name, api_key) + results.append(f"### {persona}'s Feedback:\n{feedback}\n\n") + return "\n".join(results) + + add_custom_persona_button.click( + fn=add_custom_persona, + inputs=[custom_persona_name, custom_persona_description], + outputs=persona_dropdown + ) + + get_feedback_button.click( + fn=lambda text, persona, aspect, api_name, api_key: ( + generate_writing_feedback(text, persona, aspect, api_name, api_key), + 
def create_xml_import_tab():
    """
    Build the "Import XML Files" Gradio tab.

    The API dropdown is pre-selected with the configured default endpoint when
    that endpoint is one of ``global_api_endpoints``; otherwise nothing is
    pre-selected. The import button forwards all inputs to
    ``import_xml_handler``.

    Returns:
        tuple: (import_file, title_input, keywords_input, system_prompt_input,
        custom_prompt_input, auto_summarize_checkbox, api_name_input,
        api_key_input, import_button, import_output)
    """
    # Work out the dropdown's initial selection
    default_value = None
    try:
        if default_api_endpoint and default_api_endpoint in global_api_endpoints:
            default_value = format_api_name(default_api_endpoint)
        elif default_api_endpoint:
            logging.warning(f"Default API endpoint '{default_api_endpoint}' not found in global_api_endpoints")
    except Exception as e:
        logging.error(f"Error setting default API endpoint: {str(e)}")
        default_value = None

    with gr.TabItem("Import XML Files", visible=True):
        with gr.Row():
            # Left column: inputs
            with gr.Column():
                gr.Markdown("# Import XML Files")
                gr.Markdown("Upload XML files for import")
                import_file = gr.File(label="Upload XML file", file_types=[".xml"])
                title_input = gr.Textbox(label="Title", placeholder="Enter the title of the content")
                keywords_input = gr.Textbox(label="Keywords", placeholder="Enter keywords, comma-separated")
                system_prompt_input = gr.Textbox(
                    label="System Prompt (for Summarization)",
                    lines=3,
                    value="""[Your default system prompt here]""",
                )
                custom_prompt_input = gr.Textbox(
                    label="Custom User Prompt",
                    placeholder="Enter a custom user prompt for summarization (optional)",
                )
                auto_summarize_checkbox = gr.Checkbox(label="Auto-summarize/analyze", value=False)
                api_choices = ["None"]
                api_choices.extend(format_api_name(api) for api in global_api_endpoints)
                api_name_input = gr.Dropdown(
                    choices=api_choices,
                    value=default_value,
                    label="API for Summarization/Analysis (Optional)",
                )
                api_key_input = gr.Textbox(label="API Key", type="password")
                import_button = gr.Button("Import XML File")
            # Right column: result
            with gr.Column():
                import_output = gr.Textbox(label="Import Status")

        handler_inputs = [
            import_file,
            title_input,
            keywords_input,
            system_prompt_input,
            custom_prompt_input,
            auto_summarize_checkbox,
            api_name_input,
            api_key_input,
        ]
        import_button.click(fn=import_xml_handler, inputs=handler_inputs, outputs=import_output)

    components = (
        import_file,
        title_input,
        keywords_input,
        system_prompt_input,
        custom_prompt_input,
        auto_summarize_checkbox,
        api_name_input,
        api_key_input,
        import_button,
        import_output,
    )
    return components
data = { - "max_context_length": 8096, - "max_length": 4096, "prompt": kobold_prompt, "temperature": 0.7, #"top_p": 0.9, diff --git a/App_Function_Libraries/Local_LLM/Local_LLM_Inference_Engine_Lib.py b/App_Function_Libraries/Local_LLM/Local_LLM_Inference_Engine_Lib.py index 9eaa59b62d69cec8de509f880da6d0c816093267..f8dc94658a4be01b70986ed55a8e51ed8a1579db 100644 --- a/App_Function_Libraries/Local_LLM/Local_LLM_Inference_Engine_Lib.py +++ b/App_Function_Libraries/Local_LLM/Local_LLM_Inference_Engine_Lib.py @@ -14,21 +14,19 @@ #################### # Import necessary libraries #import atexit -import glob import logging import os import re import signal import subprocess import sys -import time +from pathlib import Path from typing import List, Optional # # Import 3rd-pary Libraries import requests # # Import Local -from App_Function_Libraries.Web_Scraping.Article_Summarization_Lib import * from App_Function_Libraries.Utils.Utils import download_file # ####################################################################################################################### @@ -158,21 +156,25 @@ def get_gguf_llamafile_files(directory: str) -> List[str]: """ logging.debug(f"Scanning directory: {directory}") # Debug print for directory - # Print all files in the directory for debugging - all_files = os.listdir(directory) - logging.debug(f"All files in directory: {all_files}") - - pattern_gguf = os.path.join(directory, "*.gguf") - pattern_llamafile = os.path.join(directory, "*.llamafile") + try: + dir_path = Path(directory) + all_files = list(dir_path.iterdir()) + logging.debug(f"All files in directory: {[str(f) for f in all_files]}") + except Exception as e: + logging.error(f"Failed to list files in directory {directory}: {e}") + return [] - gguf_files = glob.glob(pattern_gguf) - llamafile_files = glob.glob(pattern_llamafile) + try: + gguf_files = list(dir_path.glob("*.gguf")) + llamafile_files = list(dir_path.glob("*.llamafile")) - # Debug: Print the files found - 
import logging
import os
import platform
import shutil
import signal
import socket
import subprocess


def is_ollama_installed():
    """
    Checks if the 'ollama' executable is available in the system's PATH.
    Returns True if installed, False otherwise.
    """
    return shutil.which('ollama') is not None


def get_ollama_models():
    """
    Retrieves available Ollama models by executing 'ollama list'.
    Returns a list of model names or an empty list if an error occurs.
    """
    try:
        result = subprocess.run(['ollama', 'list'], capture_output=True, text=True, check=True, timeout=10)
        rows = result.stdout.strip().split('\n')[1:]  # Skip header
        # The first whitespace-separated column of each row is the model name
        model_names = [row.split()[0] for row in rows if row.strip()]
        logging.debug(f"Available Ollama models: {model_names}")
        return model_names
    except FileNotFoundError:
        logging.error("Ollama executable not found. Please ensure Ollama is installed and in your PATH.")
        return []
    except subprocess.TimeoutExpired:
        logging.error("Ollama 'list' command timed out.")
        return []
    except subprocess.CalledProcessError as e:
        logging.error(f"Error executing Ollama 'list': {e}")
        return []
    except Exception as e:
        logging.error(f"Unexpected error in get_ollama_models: {e}")
        return []


def pull_ollama_model(model_name):
    """
    Pulls the specified Ollama model if Ollama is installed.
    Returns a human-readable status message; never raises.
    """
    if not is_ollama_installed():
        logging.error("Ollama is not installed.")
        return "Failed to pull model: Ollama is not installed or not in your PATH."

    # The name comes straight from a textbox; reject blanks before spawning a process.
    if not model_name or not str(model_name).strip():
        return "Failed to pull model: No model name provided."

    try:
        subprocess.run(['ollama', 'pull', model_name], check=True, timeout=300)  # Adjust timeout as needed
        logging.info(f"Successfully pulled model: {model_name}")
        return f"Successfully pulled model: {model_name}"
    except subprocess.TimeoutExpired:
        logging.error(f"Pulling model '{model_name}' timed out.")
        return f"Failed to pull model '{model_name}': Operation timed out."
    except subprocess.CalledProcessError as e:
        logging.error(f"Failed to pull model '{model_name}': {e}")
        return f"Failed to pull model '{model_name}': {e}"
    except FileNotFoundError:
        logging.error("Ollama executable not found. Please ensure Ollama is installed and in your PATH.")
        return "Failed to pull model: Ollama executable not found."
    except Exception as e:
        logging.error(f"Unexpected error in pull_ollama_model: {e}")
        return f"Failed to pull model '{model_name}': {e}"


def _is_local_port_in_use(port):
    """Return True when something accepts TCP connections on 127.0.0.1:port."""
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as probe:
        probe.settimeout(1)
        return probe.connect_ex(("127.0.0.1", port)) == 0


def serve_ollama_model(model_name, port):
    """
    Starts an Ollama server on the given port if Ollama is installed.

    `ollama serve` takes neither a model argument nor a `--port` flag; the
    listen address is configured through the OLLAMA_HOST environment variable.
    `model_name` is therefore only echoed in the status message.

    Returns a human-readable status message; never raises.
    """
    if not is_ollama_installed():
        logging.error("Ollama is not installed.")
        return "Error: Ollama is not installed or not in your PATH."

    try:
        port_number = int(port)
    except (TypeError, ValueError):
        return f"Error: Invalid port: {port}"
    if not 1 <= port_number <= 65535:
        return f"Error: Invalid port: {port}"

    try:
        # Check if a server is already running on the specified port. A plain
        # socket probe needs no elevated privileges on any platform.
        if _is_local_port_in_use(port_number):
            logging.warning(f"Port {port_number} is already in use.")
            return f"Error: Port {port_number} is already in use. Please choose a different port."

        # Start the Ollama server
        env = dict(os.environ, OLLAMA_HOST=f"127.0.0.1:{port_number}")
        # Output is discarded: an unread PIPE would eventually fill up and
        # block the server process.
        process = subprocess.Popen(
            ['ollama', 'serve'],
            env=env,
            stdout=subprocess.DEVNULL,
            stderr=subprocess.DEVNULL,
        )
        logging.info(f"Started Ollama server for model '{model_name}' on port {port_number}. PID: {process.pid}")
        return f"Started Ollama server for model '{model_name}' on port {port_number}. Process ID: {process.pid}"
    except FileNotFoundError:
        logging.error("Ollama executable not found.")
        return "Error: Ollama executable not found. Please ensure Ollama is installed and in your PATH."
    except Exception as e:
        logging.error(f"Error starting Ollama server: {e}")
        return f"Error starting Ollama server: {e}"


def stop_ollama_server(pid):
    """
    Stops the Ollama server with the specified process ID if Ollama is installed.

    The PID is validated first: on POSIX, os.kill() with 0 or a negative value
    signals whole process groups (including this application), so only
    positive integers are accepted.

    Returns a human-readable status message; never raises.
    """
    try:
        target_pid = int(pid)
    except (TypeError, ValueError):
        return f"Error: Invalid process ID: {pid}"
    if target_pid <= 0:
        return f"Error: Invalid process ID: {pid}"

    if not is_ollama_installed():
        logging.error("Ollama is not installed.")
        return "Error: Ollama is not installed or not in your PATH."

    try:
        system_name = platform.system()
        if system_name == "Windows":
            subprocess.run(['taskkill', '/F', '/PID', str(target_pid)], check=True)
        elif system_name in ["Linux", "Darwin"]:
            os.kill(target_pid, signal.SIGTERM)
        else:
            # Previously reported success without doing anything.
            logging.error(f"Stopping servers is not supported on platform '{system_name}'")
            return f"Error: Unsupported platform '{system_name}'"
        logging.info(f"Stopped Ollama server with PID {target_pid}")
        return f"Stopped Ollama server with PID {target_pid}"
    except ProcessLookupError:
        logging.warning(f"No process found with PID {target_pid}")
        return f"No process found with PID {target_pid}"
    except Exception as e:
        logging.error(f"Error stopping Ollama server: {e}")
        return f"Error stopping Ollama server: {e}"


def create_ollama_tab():
    """
    Creates the Ollama Model Serving tab in the Gradio interface with lazy loading.

    Model lists are not queried at build time; they are filled in when the
    user clicks "Refresh Model List".
    """
    # Imported here so the CLI helpers above stay importable without the UI stack.
    import gradio as gr

    ollama_installed = is_ollama_installed()

    with gr.Tab("Ollama Model Serving"):
        if not ollama_installed:
            gr.Markdown(
                "# Ollama Model Serving\n\n"
                "**Ollama is not installed or not found in your PATH. Please install Ollama to use this feature.**"
            )
            return  # Exit early, no need to add further components

        gr.Markdown("# Ollama Model Serving")

        with gr.Row():
            # Initialize Dropdowns with placeholders
            model_list = gr.Dropdown(
                label="Available Models",
                choices=["Click 'Refresh Model List' to load models"],
                value="Click 'Refresh Model List' to load models"
            )
            refresh_button = gr.Button("Refresh Model List")

        with gr.Row():
            new_model_name = gr.Textbox(label="Model to Pull", placeholder="Enter model name")
            pull_button = gr.Button("Pull Model")

        pull_output = gr.Textbox(label="Pull Status")

        with gr.Row():
            serve_model = gr.Dropdown(
                label="Model to Serve",
                choices=["Click 'Refresh Model List' to load models"],
                value="Click 'Refresh Model List' to load models"
            )
            port = gr.Number(label="Port", value=11434, precision=0)
            serve_button = gr.Button("Start Server")

        serve_output = gr.Textbox(label="Server Status")

        with gr.Row():
            pid = gr.Number(label="Server Process ID (Enter the PID to stop)", precision=0)
            stop_button = gr.Button("Stop Server")

        stop_output = gr.Textbox(label="Stop Status")

        def update_model_lists():
            """
            Retrieves the list of available Ollama models and updates the dropdowns.
            """
            models = get_ollama_models()
            if models:
                return gr.update(choices=models, value=models[0]), gr.update(choices=models, value=models[0])
            return (
                gr.update(choices=["No models found"], value="No models found"),
                gr.update(choices=["No models found"], value="No models found"),
            )

        # The updates must be returned through `outputs` to reach the browser.
        # The earlier thread + component.update() approach never changed the dropdowns.
        refresh_button.click(fn=update_model_lists, inputs=[], outputs=[model_list, serve_model])

        # Bind the pull, serve, and stop buttons to their respective functions
        pull_button.click(fn=pull_ollama_model, inputs=[new_model_name], outputs=[pull_output])
        serve_button.click(fn=serve_ollama_model, inputs=[serve_model, port], outputs=[serve_output])
        stop_button.click(fn=stop_ollama_server, inputs=[pid], outputs=[stop_output])
def pymupdf4llm_parse_pdf(pdf_path):
    """
    Convert the PDF at ``pdf_path`` to Markdown text via ``pymupdf4llm.to_markdown``.

    An attempt counter is logged before the conversion; a duration histogram
    and a success counter are logged after it. On any failure the error is
    logged, an error counter is recorded, and the exception is re-raised.
    """
    try:
        log_counter("pdf_text_extraction_attempt", labels={"file_path": pdf_path})

        started_at = datetime.now()
        markdown = pymupdf4llm.to_markdown(pdf_path)
        elapsed_seconds = (datetime.now() - started_at).total_seconds()

        log_histogram("pdf_text_extraction_duration", elapsed_seconds, labels={"file_path": pdf_path})
        log_counter("pdf_text_extraction_success", labels={"file_path": pdf_path})
    except Exception as exc:
        logging.error(f"Error extracting text and formatting from PDF: {str(exc)}")
        log_counter("pdf_text_extraction_error", labels={"file_path": pdf_path, "error": str(exc)})
        raise
    else:
        return markdown
@@ -137,8 +163,19 @@ def process_and_ingest_pdf(file, title, author, keywords): # Copy the contents of the uploaded file to the temporary file shutil.copy(file.name, temp_path) - # Extract text and convert to Markdown - markdown_text = extract_text_and_format_from_pdf(temp_path) + if parser == 'pymupdf': + # Extract text and convert to Markdown + markdown_text = extract_text_and_format_from_pdf(temp_path) + + elif parser == 'pymupdf4llm': + # Extract text and convert to Markdown + markdown_text = pymupdf4llm_parse_pdf(temp_path) + + elif parser == 'docling': + # Extract text and convert to Markdown using Docling + converter = DocumentConverter() + parsed_pdf = converter.convert(temp_path) + markdown_text = parsed_pdf.document.export_to_markdown() # Extract metadata from PDF metadata = extract_metadata_from_pdf(temp_path) @@ -185,7 +222,7 @@ def process_and_ingest_pdf(file, title, author, keywords): return f"Error ingesting PDF file: {str(e)}" -def process_and_cleanup_pdf(file, title, author, keywords): +def process_and_cleanup_pdf(file, title, author, keywords, parser='pymupdf4llm'): if file is None: log_counter("pdf_processing_error", labels={"error": "No file uploaded"}) return "No file uploaded. Please upload a PDF file." 
@@ -194,7 +231,7 @@ def process_and_cleanup_pdf(file, title, author, keywords): log_counter("pdf_processing_attempt", labels={"file_name": file.name}) start_time = datetime.now() - result = process_and_ingest_pdf(file, title, author, keywords) + result = process_and_ingest_pdf(file, title, author, keywords, parser) end_time = datetime.now() processing_time = (end_time - start_time).total_seconds() diff --git a/App_Function_Libraries/Plaintext/Plaintext_Files.py b/App_Function_Libraries/Plaintext/Plaintext_Files.py index 17adbda9c3ad3a942cda8cb46170abfeeb98c8f4..656b9ba3d399541c5463a15f19770e694b655188 100644 --- a/App_Function_Libraries/Plaintext/Plaintext_Files.py +++ b/App_Function_Libraries/Plaintext/Plaintext_Files.py @@ -2,17 +2,176 @@ # Description: This file contains functions for reading and writing plaintext files. # # Import necessary libraries -import os -import re -from datetime import datetime import logging +import os import tempfile import zipfile +from datetime import datetime + # -# Non-Local Imports +# External Imports +from docx2txt import docx2txt +from pypandoc import convert_file # # Local Imports +from App_Function_Libraries.Gradio_UI.Import_Functionality import import_data +from App_Function_Libraries.Metrics.metrics_logger import log_counter, log_histogram + + # ####################################################################################################################### # # Function Definitions + +def import_plain_text_file(file_path, author, keywords, system_prompt, user_prompt, auto_summarize, api_name, api_key): + """Import a single plain text file.""" + try: + log_counter("file_processing_attempt", labels={"file_path": file_path}) + + # Extract title from filename + title = os.path.splitext(os.path.basename(file_path))[0] + + # Determine the file type and convert if necessary + file_extension = os.path.splitext(file_path)[1].lower() + + # Get the content based on file type + try: + if file_extension == '.rtf': + with 
tempfile.NamedTemporaryFile(suffix='.md', delete=False) as temp_file: + convert_file(file_path, 'md', outputfile=temp_file.name) + file_path = temp_file.name + with open(file_path, 'r', encoding='utf-8') as file: + content = file.read() + log_counter("rtf_conversion_success", labels={"file_path": file_path}) + elif file_extension == '.docx': + content = docx2txt.process(file_path) + log_counter("docx_conversion_success", labels={"file_path": file_path}) + else: + with open(file_path, 'r', encoding='utf-8') as file: + content = file.read() + except Exception as e: + logging.error(f"Error reading file content: {str(e)}") + return f"Error reading file content: {str(e)}" + + # Import the content + result = import_data( + content, # Pass the content directly + title, + author, + keywords, + user_prompt, # This is the custom_prompt parameter + None, # No summary - let auto_summarize handle it + auto_summarize, + api_name, + api_key + ) + + log_counter("file_processing_success", labels={"file_path": file_path}) + return result + + except Exception as e: + logging.exception(f"Error processing file {file_path}") + log_counter("file_processing_error", labels={"file_path": file_path, "error": str(e)}) + return f"Error processing file {os.path.basename(file_path)}: {str(e)}" + + +def process_plain_text_zip_file(zip_file, author, keywords, system_prompt, user_prompt, auto_summarize, api_name, api_key): + """Process multiple text files from a zip archive.""" + results = [] + try: + with tempfile.TemporaryDirectory() as temp_dir: + with zipfile.ZipFile(zip_file.name, 'r') as zip_ref: + zip_ref.extractall(temp_dir) + + for filename in os.listdir(temp_dir): + if filename.lower().endswith(('.md', '.txt', '.rtf', '.docx')): + file_path = os.path.join(temp_dir, filename) + result = import_plain_text_file( + file_path=file_path, + author=author, + keywords=keywords, + system_prompt=system_prompt, + user_prompt=user_prompt, + auto_summarize=auto_summarize, + api_name=api_name, + 
api_key=api_key + ) + results.append(f"📄 {filename}: {result}") + + return "\n\n".join(results) + except Exception as e: + logging.exception(f"Error processing zip file: {str(e)}") + return f"Error processing zip file: {str(e)}" + + + +def import_file_handler(files, author, keywords, system_prompt, user_prompt, auto_summarize, api_name, api_key): + """Handle the import of one or more files, including zip files.""" + try: + if not files: + log_counter("plaintext_import_error", labels={"error": "No files uploaded"}) + return "No files uploaded." + + # Convert single file to list for consistent processing + if not isinstance(files, list): + files = [files] + + results = [] + for file in files: + log_counter("plaintext_import_attempt", labels={"file_name": file.name}) + + start_time = datetime.now() + + if not os.path.exists(file.name): + log_counter("plaintext_import_error", labels={"error": "File not found", "file_name": file.name}) + results.append(f"❌ File not found: {file.name}") + continue + + if file.name.lower().endswith(('.md', '.txt', '.rtf', '.docx')): + result = import_plain_text_file( + file_path=file.name, + author=author, + keywords=keywords, + system_prompt=system_prompt, + user_prompt=user_prompt, + auto_summarize=auto_summarize, + api_name=api_name, + api_key=api_key + ) + log_counter("plaintext_import_success", labels={"file_name": file.name}) + results.append(f"📄 {file.name}: {result}") + + elif file.name.lower().endswith('.zip'): + result = process_plain_text_zip_file( + zip_file=file, + author=author, + keywords=keywords, + system_prompt=system_prompt, + user_prompt=user_prompt, + auto_summarize=auto_summarize, + api_name=api_name, + api_key=api_key + ) + log_counter("zip_import_success", labels={"file_name": file.name}) + results.append(f"📦 {file.name}:\n{result}") + + else: + log_counter("unsupported_file_type", labels={"file_type": file.name.split('.')[-1]}) + results.append(f"❌ Unsupported file type: {file.name}") + continue + + end_time = 
datetime.now() + processing_time = (end_time - start_time).total_seconds() + log_histogram("plaintext_import_duration", processing_time, labels={"file_name": file.name}) + + return "\n\n".join(results) + + except Exception as e: + logging.exception("Error in import_file_handler") + log_counter("plaintext_import_error", labels={"error": str(e)}) + return f"❌ Error during import: {str(e)}" + +# +# End of Plaintext_Files.py +####################################################################################################################### + diff --git a/App_Function_Libraries/Plaintext/XML_Ingestion_Lib.py b/App_Function_Libraries/Plaintext/XML_Ingestion_Lib.py new file mode 100644 index 0000000000000000000000000000000000000000..45fd1b8e34f607285ad8f8b0a339bea73ac2bed8 --- /dev/null +++ b/App_Function_Libraries/Plaintext/XML_Ingestion_Lib.py @@ -0,0 +1,107 @@ +# XML_Ingestion.py +# Description: This file contains functions for reading and writing XML files. +# Imports +import logging +import xml.etree.ElementTree as ET +# +# External Imports +# +# Local Imports +from App_Function_Libraries.Gradio_UI.Import_Functionality import import_data +from App_Function_Libraries.Summarization.Summarization_General_Lib import perform_summarization +from App_Function_Libraries.Chunk_Lib import chunk_xml +from App_Function_Libraries.DB.DB_Manager import add_media_to_database +# +####################################################################################################################### +# +# Functions: + +def xml_to_text(xml_file): + try: + tree = ET.parse(xml_file) + root = tree.getroot() + # Extract text content recursively + text_content = [] + for elem in root.iter(): + if elem.text and elem.text.strip(): + text_content.append(elem.text.strip()) + return '\n'.join(text_content) + except ET.ParseError as e: + logging.error(f"Error parsing XML file: {str(e)}") + return None + + +def import_xml_handler(import_file, title, author, keywords, system_prompt, + 
custom_prompt, auto_summarize, api_name, api_key): + if not import_file: + return "Please upload an XML file" + + try: + # Parse XML and extract text with structure + tree = ET.parse(import_file.name) + root = tree.getroot() + + # Create chunk options + chunk_options = { + 'method': 'xml', + 'max_size': 1000, # Adjust as needed + 'overlap': 200, # Adjust as needed + 'language': 'english' # Add language detection if needed + } + + # Use the chunk_xml function to get structured chunks + chunks = chunk_xml(ET.tostring(root, encoding='unicode'), chunk_options) + + # Convert chunks to segments format expected by add_media_to_database + segments = [] + for chunk in chunks: + segment = { + 'Text': chunk['text'], + 'metadata': chunk['metadata'] # Preserve XML structure metadata + } + segments.append(segment) + + # Create info_dict + info_dict = { + 'title': title or 'Untitled XML Document', + 'uploader': author or 'Unknown', + 'file_type': 'xml', + 'structure': root.tag # Save root element type + } + + # Process keywords + keyword_list = [kw.strip() for kw in keywords.split(',') if kw.strip()] if keywords else [] + + # Handle summarization + if auto_summarize and api_name and api_key: + # Combine all chunks for summarization + full_text = '\n'.join(chunk['text'] for chunk in chunks) + summary = perform_summarization(api_name, full_text, custom_prompt, api_key) + else: + summary = "No summary provided" + + # Add to database + result = add_media_to_database( + url=import_file.name, # Using filename as URL + info_dict=info_dict, + segments=segments, + summary=summary, + keywords=keyword_list, + custom_prompt_input=custom_prompt, + whisper_model="XML Import", + media_type="xml_document", + overwrite=False + ) + + return f"XML file '{import_file.name}' import complete. 
Database result: {result}" + + except ET.ParseError as e: + logging.error(f"XML parsing error: {str(e)}") + return f"Error parsing XML file: {str(e)}" + except Exception as e: + logging.error(f"Error processing XML file: {str(e)}") + return f"Error processing XML file: {str(e)}" + +# +# End of XML_Ingestion_Lib.py +####################################################################################################################### diff --git a/App_Function_Libraries/Plaintext/__init__.py b/App_Function_Libraries/Plaintext/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/App_Function_Libraries/Prompt_Engineering/Prompt_Engineering.py b/App_Function_Libraries/Prompt_Engineering/Prompt_Engineering.py index caba043057a43394c829f251a7535e745b2de7ef..d190355a96547728b570433d3b28deb3addb2fbf 100644 --- a/App_Function_Libraries/Prompt_Engineering/Prompt_Engineering.py +++ b/App_Function_Libraries/Prompt_Engineering/Prompt_Engineering.py @@ -4,7 +4,7 @@ # Imports import re -from App_Function_Libraries.Chat import chat_api_call +from App_Function_Libraries.Chat.Chat_Functions import chat_api_call # # Local Imports # diff --git a/App_Function_Libraries/Prompt_Engineering/__Init__.py b/App_Function_Libraries/Prompt_Engineering/__Init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/App_Function_Libraries/RAG/ChromaDB_Library.py b/App_Function_Libraries/RAG/ChromaDB_Library.py index beb3aa80897b21f14b34927dd78c5eb87e2c1fae..00ad7d156a38d548154be5363abacddfd01c943d 100644 --- a/App_Function_Libraries/RAG/ChromaDB_Library.py +++ b/App_Function_Libraries/RAG/ChromaDB_Library.py @@ -49,36 +49,37 @@ embedding_api_url = config.get('Embeddings', 'api_url', fallback='') # Function to preprocess and store all existing content in the database -def preprocess_all_content(database, create_contextualized=True, 
api_name="gpt-3.5-turbo"): - unprocessed_media = get_unprocessed_media(db=database) - total_media = len(unprocessed_media) - - for index, row in enumerate(unprocessed_media, 1): - media_id, content, media_type, file_name = row - collection_name = f"{media_type}_{media_id}" - - logger.info(f"Processing media {index} of {total_media}: ID {media_id}, Type {media_type}") - - try: - process_and_store_content( - database=database, - content=content, - collection_name=collection_name, - media_id=media_id, - file_name=file_name or f"{media_type}_{media_id}", - create_embeddings=True, - create_contextualized=create_contextualized, - api_name=api_name - ) - - # Mark the media as processed in the database - mark_media_as_processed(database, media_id) - - logger.info(f"Successfully processed media ID {media_id}") - except Exception as e: - logger.error(f"Error processing media ID {media_id}: {str(e)}") - - logger.info("Finished preprocessing all unprocessed content") +# FIXME - Deprecated +# def preprocess_all_content(database, create_contextualized=True, api_name="gpt-3.5-turbo"): +# unprocessed_media = get_unprocessed_media(db=database) +# total_media = len(unprocessed_media) +# +# for index, row in enumerate(unprocessed_media, 1): +# media_id, content, media_type, file_name = row +# collection_name = f"{media_type}_{media_id}" +# +# logger.info(f"Processing media {index} of {total_media}: ID {media_id}, Type {media_type}") +# +# try: +# process_and_store_content( +# database=database, +# content=content, +# collection_name=collection_name, +# media_id=media_id, +# file_name=file_name or f"{media_type}_{media_id}", +# create_embeddings=True, +# create_contextualized=create_contextualized, +# api_name=api_name +# ) +# +# # Mark the media as processed in the database +# mark_media_as_processed(database, media_id) +# +# logger.info(f"Successfully processed media ID {media_id}") +# except Exception as e: +# logger.error(f"Error processing media ID {media_id}: {str(e)}") +# +# 
logger.info("Finished preprocessing all unprocessed content") def batched(iterable, n): @@ -233,7 +234,10 @@ def store_in_chroma(collection_name: str, texts: List[str], embeddings: Any, ids logging.info(f"Number of embeddings: {len(embeddings)}, Dimension: {embedding_dim}") try: - # Attempt to get or create the collection + # Clean metadata + cleaned_metadatas = [clean_metadata(metadata) for metadata in metadatas] + + # Try to get or create the collection try: collection = chroma_client.get_collection(name=collection_name) logging.info(f"Existing collection '{collection_name}' found") @@ -258,7 +262,7 @@ def store_in_chroma(collection_name: str, texts: List[str], embeddings: Any, ids documents=texts, embeddings=embeddings, ids=ids, - metadatas=metadatas + metadatas=cleaned_metadatas ) logging.info(f"Successfully upserted {len(embeddings)} embeddings") @@ -290,12 +294,19 @@ def vector_search(collection_name: str, query: str, k: int = 10) -> List[Dict[st # Fetch a sample of embeddings to check metadata sample_results = collection.get(limit=10, include=["metadatas"]) - if not sample_results['metadatas']: - raise ValueError("No metadata found in the collection") + if not sample_results.get('metadatas') or not any(sample_results['metadatas']): + logging.warning(f"No metadata found in the collection '{collection_name}'. 
Skipping this collection.") + return [] # Check if all embeddings use the same model and provider - embedding_models = [metadata.get('embedding_model') for metadata in sample_results['metadatas'] if metadata.get('embedding_model')] - embedding_providers = [metadata.get('embedding_provider') for metadata in sample_results['metadatas'] if metadata.get('embedding_provider')] + embedding_models = [ + metadata.get('embedding_model') for metadata in sample_results['metadatas'] + if metadata and metadata.get('embedding_model') + ] + embedding_providers = [ + metadata.get('embedding_provider') for metadata in sample_results['metadatas'] + if metadata and metadata.get('embedding_provider') + ] if not embedding_models or not embedding_providers: raise ValueError("Embedding model or provider information not found in metadata") @@ -319,13 +330,13 @@ def vector_search(collection_name: str, query: str, k: int = 10) -> List[Dict[st ) if not results['documents'][0]: - logging.warning("No results found for the query") + logging.warning(f"No results found for the query in collection '{collection_name}'.") return [] return [{"content": doc, "metadata": meta} for doc, meta in zip(results['documents'][0], results['metadatas'][0])] except Exception as e: - logging.error(f"Error in vector_search: {str(e)}", exc_info=True) - raise + logging.error(f"Error in vector_search for collection '{collection_name}': {str(e)}", exc_info=True) + return [] def schedule_embedding(media_id: int, content: str, media_name: str): @@ -350,6 +361,21 @@ def schedule_embedding(media_id: int, content: str, media_name: str): logging.error(f"Error scheduling embedding for media_id {media_id}: {str(e)}") +def clean_metadata(metadata: Dict[str, Any]) -> Dict[str, Any]: + """Clean metadata by removing None values and converting to appropriate types""" + cleaned = {} + for key, value in metadata.items(): + if value is not None: # Skip None values + if isinstance(value, (str, int, float, bool)): + cleaned[key] = value 
+ elif isinstance(value, (np.int32, np.int64)): + cleaned[key] = int(value) + elif isinstance(value, (np.float32, np.float64)): + cleaned[key] = float(value) + else: + cleaned[key] = str(value) # Convert other types to string + return cleaned + # Function to process content, create chunks, embeddings, and store in ChromaDB and SQLite # def process_and_store_content(content: str, collection_name: str, media_id: int): # # Process the content into chunks diff --git a/App_Function_Libraries/RAG/Embeddings_Create.py b/App_Function_Libraries/RAG/Embeddings_Create.py index 2c732ec2a70bb92bb313eed43920c720cae23330..918c99af0e95639f1cff56eea087b4978317f026 100644 --- a/App_Function_Libraries/RAG/Embeddings_Create.py +++ b/App_Function_Libraries/RAG/Embeddings_Create.py @@ -25,8 +25,6 @@ from App_Function_Libraries.Metrics.metrics_logger import log_counter, log_histo # # Functions: -# FIXME - Version 2 - # Load configuration loaded_config = load_comprehensive_config() embedding_provider = loaded_config['Embeddings']['embedding_provider'] @@ -331,177 +329,6 @@ def create_openai_embedding(text: str, model: str) -> List[float]: return embedding -# FIXME - Version 1 -# # FIXME - Add all globals to summarize.py -# loaded_config = load_comprehensive_config() -# embedding_provider = loaded_config['Embeddings']['embedding_provider'] -# embedding_model = loaded_config['Embeddings']['embedding_model'] -# embedding_api_url = loaded_config['Embeddings']['embedding_api_url'] -# embedding_api_key = loaded_config['Embeddings']['embedding_api_key'] -# -# # Embedding Chunking Settings -# chunk_size = loaded_config['Embeddings']['chunk_size'] -# overlap = loaded_config['Embeddings']['overlap'] -# -# -# # FIXME - Add logging -# -# class HuggingFaceEmbedder: -# def __init__(self, model_name, timeout_seconds=120): # Default timeout of 2 minutes -# self.model_name = model_name -# self.tokenizer = None -# self.model = None -# self.device = torch.device("cuda" if torch.cuda.is_available() else 
"cpu") -# self.timeout_seconds = timeout_seconds -# self.last_used_time = 0 -# self.unload_timer = None -# -# def load_model(self): -# if self.model is None: -# self.tokenizer = AutoTokenizer.from_pretrained(self.model_name) -# self.model = AutoModel.from_pretrained(self.model_name) -# self.model.to(self.device) -# self.last_used_time = time.time() -# self.reset_timer() -# -# def unload_model(self): -# if self.model is not None: -# del self.model -# del self.tokenizer -# if torch.cuda.is_available(): -# torch.cuda.empty_cache() -# self.model = None -# self.tokenizer = None -# if self.unload_timer: -# self.unload_timer.cancel() -# -# def reset_timer(self): -# if self.unload_timer: -# self.unload_timer.cancel() -# self.unload_timer = Timer(self.timeout_seconds, self.unload_model) -# self.unload_timer.start() -# -# def create_embeddings(self, texts): -# self.load_model() -# inputs = self.tokenizer(texts, return_tensors="pt", padding=True, truncation=True, max_length=512) -# inputs = {k: v.to(self.device) for k, v in inputs.items()} -# with torch.no_grad(): -# outputs = self.model(**inputs) -# embeddings = outputs.last_hidden_state.mean(dim=1) -# return embeddings.cpu().numpy() -# -# # Global variable to hold the embedder -# huggingface_embedder = None -# -# -# class RateLimiter: -# def __init__(self, max_calls, period): -# self.max_calls = max_calls -# self.period = period -# self.calls = [] -# self.lock = Lock() -# -# def __call__(self, func): -# def wrapper(*args, **kwargs): -# with self.lock: -# now = time.time() -# self.calls = [call for call in self.calls if call > now - self.period] -# if len(self.calls) >= self.max_calls: -# sleep_time = self.calls[0] - (now - self.period) -# time.sleep(sleep_time) -# self.calls.append(time.time()) -# return func(*args, **kwargs) -# return wrapper -# -# -# def exponential_backoff(max_retries=5, base_delay=1): -# def decorator(func): -# @wraps(func) -# def wrapper(*args, **kwargs): -# for attempt in range(max_retries): -# try: 
-# return func(*args, **kwargs) -# except Exception as e: -# if attempt == max_retries - 1: -# raise -# delay = base_delay * (2 ** attempt) -# logging.warning(f"Attempt {attempt + 1} failed. Retrying in {delay} seconds. Error: {str(e)}") -# time.sleep(delay) -# return wrapper -# return decorator -# -# -# # FIXME - refactor/setup to use config file & perform chunking -# @exponential_backoff() -# @RateLimiter(max_calls=50, period=60) -# def create_embeddings_batch(texts: List[str], provider: str, model: str, api_url: str, timeout_seconds: int = 300) -> List[List[float]]: -# global embedding_models -# -# try: -# if provider.lower() == 'huggingface': -# if model not in embedding_models: -# if model == "dunzhang/stella_en_400M_v5": -# embedding_models[model] = ONNXEmbedder(model, model_dir, timeout_seconds) -# else: -# embedding_models[model] = HuggingFaceEmbedder(model, timeout_seconds) -# embedder = embedding_models[model] -# return embedder.create_embeddings(texts) -# -# elif provider.lower() == 'openai': -# logging.debug(f"Creating embeddings for {len(texts)} texts using OpenAI API") -# return [create_openai_embedding(text, model) for text in texts] -# -# elif provider.lower() == 'local': -# response = requests.post( -# api_url, -# json={"texts": texts, "model": model}, -# headers={"Authorization": f"Bearer {embedding_api_key}"} -# ) -# if response.status_code == 200: -# return response.json()['embeddings'] -# else: -# raise Exception(f"Error from local API: {response.text}") -# else: -# raise ValueError(f"Unsupported embedding provider: {provider}") -# except Exception as e: -# logging.error(f"Error in create_embeddings_batch: {str(e)}") -# raise -# -# def create_embedding(text: str, provider: str, model: str, api_url: str) -> List[float]: -# return create_embeddings_batch([text], provider, model, api_url)[0] -# -# -# def create_openai_embedding(text: str, model: str) -> List[float]: -# embedding = get_openai_embeddings(text, model) -# return embedding -# -# -# # 
FIXME - refactor to use onnx embeddings callout -# def create_stella_embeddings(text: str) -> List[float]: -# if embedding_provider == 'local': -# # Load the model and tokenizer -# tokenizer = AutoTokenizer.from_pretrained("dunzhang/stella_en_400M_v5") -# model = AutoModel.from_pretrained("dunzhang/stella_en_400M_v5") -# -# # Tokenize and encode the text -# inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512) -# -# # Generate embeddings -# with torch.no_grad(): -# outputs = model(**inputs) -# -# # Use the mean of the last hidden state as the sentence embedding -# embeddings = outputs.last_hidden_state.mean(dim=1) -# -# return embeddings[0].tolist() # Convert to list for consistency -# elif embedding_provider == 'openai': -# return get_openai_embeddings(text, embedding_model) -# else: -# raise ValueError(f"Unsupported embedding provider: {embedding_provider}") -# # -# # End of F -# ############################################################## -# # # ############################################################## # # diff --git a/App_Function_Libraries/RAG/RAG_Library_2.py b/App_Function_Libraries/RAG/RAG_Library_2.py index 68a0c8b91361b904f1843893c8afd9848450ed8c..d903abffe578337aa40589807c2a7808f7972220 100644 --- a/App_Function_Libraries/RAG/RAG_Library_2.py +++ b/App_Function_Libraries/RAG/RAG_Library_2.py @@ -9,14 +9,16 @@ import time from typing import Dict, Any, List, Optional from App_Function_Libraries.DB.Character_Chat_DB import get_character_chats, perform_full_text_search_chat, \ - fetch_keywords_for_chats + fetch_keywords_for_chats, search_character_chat, search_character_cards, fetch_character_ids_by_keywords +from App_Function_Libraries.DB.RAG_QA_Chat_DB import search_rag_chat, search_rag_notes # # Local Imports from App_Function_Libraries.RAG.ChromaDB_Library import process_and_store_content, vector_search, chroma_client from App_Function_Libraries.RAG.RAG_Persona_Chat import perform_vector_search_chat from 
App_Function_Libraries.Summarization.Local_Summarization_Lib import summarize_with_custom_openai from App_Function_Libraries.Web_Scraping.Article_Extractor_Lib import scrape_article -from App_Function_Libraries.DB.DB_Manager import search_db, fetch_keywords_for_media +from App_Function_Libraries.DB.DB_Manager import fetch_keywords_for_media, search_media_db, get_notes_by_keywords, \ + search_conversations_by_keywords from App_Function_Libraries.Utils.Utils import load_comprehensive_config from App_Function_Libraries.Metrics.metrics_logger import log_counter, log_histogram # @@ -40,6 +42,15 @@ config = configparser.ConfigParser() # Read the configuration file config.read('config.txt') + +search_functions = { + "Media DB": search_media_db, + "RAG Chat": search_rag_chat, + "RAG Notes": search_rag_notes, + "Character Chat": search_character_chat, + "Character Cards": search_character_cards +} + # RAG pipeline function for web scraping # def rag_web_scraping_pipeline(url: str, query: str, api_choice=None) -> Dict[str, Any]: # try: @@ -117,7 +128,20 @@ config.read('config.txt') # RAG Search with keyword filtering # FIXME - Update each called function to support modifiable top-k results -def enhanced_rag_pipeline(query: str, api_choice: str, keywords: str = None, top_k=10, apply_re_ranking=True) -> Dict[str, Any]: +def enhanced_rag_pipeline(query: str, api_choice: str, keywords: str = None, fts_top_k=10, apply_re_ranking=True, database_types: List[str] = "Media DB") -> Dict[str, Any]: + """ + Perform full text search across specified database type. 
+ + Args: + query: Search query string + api_choice: API to use for generating the response + fts_top_k: Maximum number of results to return + keywords: Optional list of media IDs to filter results + database_types: Type of database to search ("Media DB", "RAG Chat", or "Character Chat") + + Returns: + Dictionary containing search results with content + """ log_counter("enhanced_rag_pipeline_attempt", labels={"api_choice": api_choice}) start_time = time.time() try: @@ -131,16 +155,97 @@ def enhanced_rag_pipeline(query: str, api_choice: str, keywords: str = None, top keyword_list = [k.strip().lower() for k in keywords.split(',')] if keywords else [] logging.debug(f"\n\nenhanced_rag_pipeline - Keywords: {keyword_list}") - # Fetch relevant media IDs based on keywords if keywords are provided - relevant_media_ids = fetch_relevant_media_ids(keyword_list) if keyword_list else None - logging.debug(f"\n\nenhanced_rag_pipeline - relevant media IDs: {relevant_media_ids}") + relevant_ids = {} + + # Fetch relevant IDs based on keywords if keywords are provided + if keyword_list: + try: + for db_type in database_types: + if db_type == "Media DB": + relevant_media_ids = fetch_relevant_media_ids(keyword_list) + relevant_ids[db_type] = relevant_media_ids + logging.debug(f"enhanced_rag_pipeline - {db_type} relevant media IDs: {relevant_media_ids}") + + elif db_type == "RAG Chat": + conversations, total_pages, total_count = search_conversations_by_keywords( + keywords=keyword_list) + relevant_conversation_ids = [conv['conversation_id'] for conv in conversations] + relevant_ids[db_type] = relevant_conversation_ids + logging.debug( + f"enhanced_rag_pipeline - {db_type} relevant conversation IDs: {relevant_conversation_ids}") + + elif db_type == "RAG Notes": + notes, total_pages, total_count = get_notes_by_keywords(keyword_list) + relevant_note_ids = [note_id for note_id, _, _, _ in notes] # Unpack note_id from the tuple + relevant_ids[db_type] = relevant_note_ids + 
logging.debug(f"enhanced_rag_pipeline - {db_type} relevant note IDs: {relevant_note_ids}") + + elif db_type == "Character Chat": + relevant_chat_ids = fetch_keywords_for_chats(keyword_list) + relevant_ids[db_type] = relevant_chat_ids + logging.debug(f"enhanced_rag_pipeline - {db_type} relevant chat IDs: {relevant_chat_ids}") + + elif db_type == "Character Cards": + # Assuming we have a function to fetch character IDs by keywords + relevant_character_ids = fetch_character_ids_by_keywords(keyword_list) + relevant_ids[db_type] = relevant_character_ids + logging.debug( + f"enhanced_rag_pipeline - {db_type} relevant character IDs: {relevant_character_ids}") + + else: + logging.error(f"Unsupported database type: {db_type}") + + except Exception as e: + logging.error(f"Error fetching relevant IDs: {str(e)}") + else: + relevant_ids = None + + # Extract relevant media IDs for each selected DB + # Prepare a dict to hold relevant_media_ids per DB + relevant_media_ids_dict = {} + if relevant_ids: + for db_type in database_types: + relevant_media_ids = relevant_ids.get(db_type, None) + if relevant_media_ids: + # Convert to List[str] if not None + relevant_media_ids_dict[db_type] = [str(media_id) for media_id in relevant_media_ids] + else: + relevant_media_ids_dict[db_type] = None + else: + relevant_media_ids_dict = {db_type: None for db_type in database_types} + + # Perform vector search for all selected databases + vector_results = [] + for db_type in database_types: + try: + db_relevant_ids = relevant_media_ids_dict.get(db_type) + results = perform_vector_search(query, db_relevant_ids, top_k=fts_top_k) + vector_results.extend(results) + logging.debug(f"\nenhanced_rag_pipeline - Vector search results for {db_type}: {results}") + except Exception as e: + logging.error(f"Error performing vector search on {db_type}: {str(e)}") # Perform vector search + # FIXME vector_results = perform_vector_search(query, relevant_media_ids) logging.debug(f"\n\nenhanced_rag_pipeline - Vector 
search results: {vector_results}") # Perform full-text search - fts_results = perform_full_text_search(query, relevant_media_ids) + #v1 + #fts_results = perform_full_text_search(query, database_type, relevant_media_ids, fts_top_k) + + # v2 + # Perform full-text search across specified databases + fts_results = [] + for db_type in database_types: + try: + db_relevant_ids = relevant_ids.get(db_type) if relevant_ids else None + db_results = perform_full_text_search(query, db_type, db_relevant_ids, fts_top_k) + fts_results.extend(db_results) + logging.debug(f"enhanced_rag_pipeline - FTS results for {db_type}: {db_results}") + except Exception as e: + logging.error(f"Error performing full-text search on {db_type}: {str(e)}") + logging.debug("\n\nenhanced_rag_pipeline - Full-text search results:") logging.debug( "\n\nenhanced_rag_pipeline - Full-text search results:\n" + "\n".join( @@ -175,8 +280,8 @@ def enhanced_rag_pipeline(query: str, api_choice: str, keywords: str = None, top # Update all_results based on reranking all_results = [all_results[result['id']] for result in reranked_results] - # Extract content from results (top 10 by default) - context = "\n".join([result['content'] for result in all_results[:top_k]]) + # Extract content from results (top fts_top_k by default) + context = "\n".join([result['content'] for result in all_results[:fts_top_k]]) logging.debug(f"Context length: {len(context)}") logging.debug(f"Context: {context[:200]}") @@ -208,6 +313,8 @@ def enhanced_rag_pipeline(query: str, api_choice: str, keywords: str = None, top "context": "" } + + # Need to write a test for this function FIXME def generate_answer(api_choice: str, context: str, query: str) -> str: # Metrics @@ -336,6 +443,7 @@ def generate_answer(api_choice: str, context: str, query: str) -> str: logging.error(f"Error in generate_answer: {str(e)}") return "An error occurred while generating the answer." 
+ def perform_vector_search(query: str, relevant_media_ids: List[str] = None, top_k=10) -> List[Dict[str, Any]]: log_counter("perform_vector_search_attempt") start_time = time.time() @@ -344,6 +452,8 @@ def perform_vector_search(query: str, relevant_media_ids: List[str] = None, top_ try: for collection in all_collections: collection_results = vector_search(collection.name, query, k=top_k) + if not collection_results: + continue # Skip empty results filtered_results = [ result for result in collection_results if relevant_media_ids is None or result['metadata'].get('media_id') in relevant_media_ids @@ -358,29 +468,75 @@ def perform_vector_search(query: str, relevant_media_ids: List[str] = None, top_ logging.error(f"Error in perform_vector_search: {str(e)}") raise -def perform_full_text_search(query: str, relevant_media_ids: List[str] = None, fts_top_k=None) -> List[Dict[str, Any]]: - log_counter("perform_full_text_search_attempt") + +# V2 +def perform_full_text_search(query: str, database_type: str, relevant_ids: List[str] = None, fts_top_k=None) -> List[Dict[str, Any]]: + """ + Perform full-text search on a specified database type. 
+ + Args: + query: Search query string + database_type: Type of database to search ("Media DB", "RAG Chat", "RAG Notes", "Character Chat", "Character Cards") + relevant_ids: Optional list of media IDs to filter results + fts_top_k: Maximum number of results to return + + Returns: + List of search results with content and metadata + """ + log_counter("perform_full_text_search_attempt", labels={"database_type": database_type}) start_time = time.time() + try: - fts_results = search_db(query, ["content"], "", page=1, results_per_page=fts_top_k or 10) - filtered_fts_results = [ - { - "content": result['content'], - "metadata": {"media_id": result['id']} - } - for result in fts_results - if relevant_media_ids is None or result['id'] in relevant_media_ids - ] + # Set default for fts_top_k + if fts_top_k is None: + fts_top_k = 10 + + # Call appropriate search function based on database type + if database_type not in search_functions: + raise ValueError(f"Unsupported database type: {database_type}") + + # Call the appropriate search function + results = search_functions[database_type](query, fts_top_k, relevant_ids) + search_duration = time.time() - start_time - log_histogram("perform_full_text_search_duration", search_duration) - log_counter("perform_full_text_search_success", labels={"result_count": len(filtered_fts_results)}) - return filtered_fts_results + log_histogram("perform_full_text_search_duration", search_duration, + labels={"database_type": database_type}) + log_counter("perform_full_text_search_success", + labels={"database_type": database_type, "result_count": len(results)}) + + return results + except Exception as e: - log_counter("perform_full_text_search_error", labels={"error": str(e)}) - logging.error(f"Error in perform_full_text_search: {str(e)}") + log_counter("perform_full_text_search_error", + labels={"database_type": database_type, "error": str(e)}) + logging.error(f"Error in perform_full_text_search ({database_type}): {str(e)}") raise +# v1 +# def 
perform_full_text_search(query: str, relevant_media_ids: List[str] = None, fts_top_k=None) -> List[Dict[str, Any]]: +# log_counter("perform_full_text_search_attempt") +# start_time = time.time() +# try: +# fts_results = search_db(query, ["content"], "", page=1, results_per_page=fts_top_k or 10) +# filtered_fts_results = [ +# { +# "content": result['content'], +# "metadata": {"media_id": result['id']} +# } +# for result in fts_results +# if relevant_media_ids is None or result['id'] in relevant_media_ids +# ] +# search_duration = time.time() - start_time +# log_histogram("perform_full_text_search_duration", search_duration) +# log_counter("perform_full_text_search_success", labels={"result_count": len(filtered_fts_results)}) +# return filtered_fts_results +# except Exception as e: +# log_counter("perform_full_text_search_error", labels={"error": str(e)}) +# logging.error(f"Error in perform_full_text_search: {str(e)}") +# raise + + def fetch_relevant_media_ids(keywords: List[str], top_k=10) -> List[int]: log_counter("fetch_relevant_media_ids_attempt", labels={"keyword_count": len(keywords)}) start_time = time.time() @@ -502,6 +658,7 @@ def enhanced_rag_pipeline_chat(query: str, api_choice: str, character_id: int, k logging.debug(f"enhanced_rag_pipeline_chat - Vector search results: {vector_results}") # Perform full-text search within the relevant chats + # FIXME - Update for DB Selection fts_results = perform_full_text_search_chat(query, relevant_chat_ids) logging.debug("enhanced_rag_pipeline_chat - Full-text search results:") logging.debug("\n".join([str(item) for item in fts_results])) diff --git a/App_Function_Libraries/RAG/RAG_QA_Chat.py b/App_Function_Libraries/RAG/RAG_QA_Chat.py index 440f3ae67946b3bc3849345011695be4cf6c3680..10a770b830b12f295036165ac4a3f908f829988e 100644 --- a/App_Function_Libraries/RAG/RAG_QA_Chat.py +++ b/App_Function_Libraries/RAG/RAG_QA_Chat.py @@ -12,7 +12,7 @@ import time from typing import List, Tuple, IO, Union # # Local Imports -from 
App_Function_Libraries.DB.DB_Manager import db, search_db, DatabaseError, get_media_content +from App_Function_Libraries.DB.DB_Manager import db, search_media_db, DatabaseError, get_media_content from App_Function_Libraries.RAG.RAG_Library_2 import generate_answer, enhanced_rag_pipeline from App_Function_Libraries.Metrics.metrics_logger import log_counter, log_histogram # @@ -89,7 +89,7 @@ def search_database(query: str) -> List[Tuple[int, str]]: log_counter("search_database_attempt") start_time = time.time() # Implement database search functionality - results = search_db(query, ["title", "content"], "", page=1, results_per_page=10) + results = search_media_db(query, ["title", "content"], "", page=1, results_per_page=10) search_duration = time.time() - start_time log_histogram("search_database_duration", search_duration) log_counter("search_database_success", labels={"result_count": len(results)}) diff --git a/App_Function_Libraries/Summarization/Summarization_General_Lib.py b/App_Function_Libraries/Summarization/Summarization_General_Lib.py index af137da9d78567fea626be8123824fb58d3ef3fb..03f2289fd52365907d4b422800362e902e1ec998 100644 --- a/App_Function_Libraries/Summarization/Summarization_General_Lib.py +++ b/App_Function_Libraries/Summarization/Summarization_General_Lib.py @@ -1111,79 +1111,144 @@ def process_video_urls(url_list, num_speakers, whisper_model, custom_prompt_inpu return current_progress, success_message, None, None, None, None - -def perform_transcription(video_path, offset, whisper_model, vad_filter, diarize=False): +def perform_transcription(video_path, offset, whisper_model, vad_filter, diarize=False, overwrite=False): temp_files = [] logging.info(f"Processing media: {video_path}") global segments_json_path audio_file_path = convert_to_wav(video_path, offset) logging.debug(f"Converted audio file: {audio_file_path}") temp_files.append(audio_file_path) - logging.debug("Replacing audio file with segments.json file") - segments_json_path = 
audio_file_path.replace('.wav', '.segments.json') + logging.debug("Setting up segments JSON path") + + # Update path to include whisper model in filename + base_path = audio_file_path.replace('.wav', '') + segments_json_path = f"{base_path}-whisper_model-{whisper_model}.segments.json" temp_files.append(segments_json_path) if diarize: - diarized_json_path = audio_file_path.replace('.wav', '.diarized.json') + diarized_json_path = f"{base_path}-whisper_model-{whisper_model}.diarized.json" - # Check if diarized JSON already exists + # Check if diarized JSON already exists and is valid if os.path.exists(diarized_json_path): logging.info(f"Diarized file already exists: {diarized_json_path}") try: - with open(diarized_json_path, 'r') as file: + with open(diarized_json_path, 'r', encoding='utf-8') as file: diarized_segments = json.load(file) - if not diarized_segments: - logging.warning(f"Diarized JSON file is empty, re-generating: {diarized_json_path}") - raise ValueError("Empty diarized JSON file") - logging.debug(f"Loaded diarized segments from {diarized_json_path}") + # Check if segments are empty or invalid + if not diarized_segments or not isinstance(diarized_segments, list): + if not overwrite: + logging.info("Overwrite flag not set. Existing file not overwritten.") + return None, "Overwrite flag not set. Existing file not overwritten." + logging.warning(f"Diarized JSON file is empty or invalid, re-generating: {diarized_json_path}") + raise ValueError("Invalid diarized JSON file") + # Check if segments contain expected content + if not all('Text' in segment for segment in diarized_segments): + if not overwrite: + logging.info("Overwrite flag not set. Existing file not overwritten.") + return None, "Overwrite flag not set. Existing file not overwritten." 
+ logging.warning(f"Diarized segments missing required fields, re-generating: {diarized_json_path}") + raise ValueError("Invalid segment format") + logging.debug(f"Loaded valid diarized segments from {diarized_json_path}") return audio_file_path, diarized_segments except (json.JSONDecodeError, ValueError) as e: + if not overwrite: + logging.info("Overwrite flag not set. Existing file not overwritten.") + return None, "Overwrite flag not set. Existing file not overwritten." logging.error(f"Failed to read or parse the diarized JSON file: {e}") - os.remove(diarized_json_path) + if os.path.exists(diarized_json_path): + os.remove(diarized_json_path) - # If diarized file doesn't exist or was corrupted, generate new diarized transcription + # Generate new diarized transcription logging.info(f"Generating diarized transcription for {audio_file_path}") diarized_segments = combine_transcription_and_diarization(audio_file_path) + # Validate diarized segments before saving + if not diarized_segments or not isinstance(diarized_segments, list): + logging.error("Generated diarized segments are empty or invalid") + return None, None + # Save diarized segments - with open(diarized_json_path, 'w') as file: - json.dump(diarized_segments, file, indent=2) + json_str = json.dumps(diarized_segments, indent=2) + with open(diarized_json_path, 'w', encoding='utf-8') as f: + f.write(json_str) return audio_file_path, diarized_segments # Non-diarized transcription - if os.path.exists(segments_json_path): - logging.info(f"Segments file already exists: {segments_json_path}") - try: - with open(segments_json_path, 'r') as file: - segments = json.load(file) - if not segments: - logging.warning(f"Segments JSON file is empty, re-generating: {segments_json_path}") - raise ValueError("Empty segments JSON file") - logging.debug(f"Loaded segments from {segments_json_path}") - except (json.JSONDecodeError, ValueError) as e: - logging.error(f"Failed to read or parse the segments JSON file: {e}") - 
os.remove(segments_json_path) - logging.info(f"Re-generating transcription for {audio_file_path}") - audio_file, segments = re_generate_transcription(audio_file_path, whisper_model, vad_filter) - if segments is None: - return None, None - else: + try: + # If segments file exists, try to load it + if os.path.exists(segments_json_path): + logging.info(f"Segments file already exists: {segments_json_path}") + try: + with open(segments_json_path, 'r', encoding='utf-8') as file: + segments = json.load(file) + # Check if segments are empty or invalid + if not segments or not isinstance(segments, list): + if not overwrite: + logging.info("Overwrite flag not set. Existing file not overwritten.") + return None, "Overwrite flag not set. Existing file not overwritten." + raise ValueError("Invalid segments JSON file") + # Check if segments contain expected content + if not all( + isinstance(segment, dict) and all(key in segment for key in ['Text', 'Time_Start', 'Time_End']) + for segment in segments): + if not overwrite: + logging.info("Overwrite flag not set. Existing file not overwritten.") + return None, "Overwrite flag not set. Existing file not overwritten." + raise ValueError("Invalid segment format") + logging.debug(f"Loaded valid segments from {segments_json_path}") + return audio_file_path, segments + except (json.JSONDecodeError, ValueError, KeyError) as e: + if not overwrite: + logging.info("Overwrite flag not set. Existing file not overwritten.") + return None, "Overwrite flag not set. Existing file not overwritten." 
+ logging.error(f"Failed to read or parse the segments JSON file: {str(e)}") + if os.path.exists(segments_json_path): + os.remove(segments_json_path) + + # Generate new transcription if file doesn't exist audio_file, segments = re_generate_transcription(audio_file_path, whisper_model, vad_filter) + if segments is None: + logging.error("Failed to generate new transcription") + return None, None + + return audio_file_path, segments - return audio_file_path, segments + except Exception as e: + logging.error(f"Error in perform_transcription: {str(e)}") + return None, None def re_generate_transcription(audio_file_path, whisper_model, vad_filter): + global segments_json_path try: + logging.info(f"Generating new transcription for {audio_file_path}") segments = speech_to_text(audio_file_path, whisper_model=whisper_model, vad_filter=vad_filter) + + # Print the first few segments for debugging + logging.debug(f"First few segments from speech_to_text: {segments[:2] if segments else 'None'}") + + # Validate segments before saving + if not segments or not isinstance(segments, list): + logging.error("Generated segments are empty or invalid") + return None, None + + # More detailed validation + if not all(isinstance(segment, dict) and all(key in segment for key in ['Text', 'Time_Start', 'Time_End']) for + segment in segments): + logging.error("Generated segments are missing required fields or have invalid format") + logging.debug(f"Segments structure: {segments[:2]}") # Log first two segments for debugging + return None, None + # Save segments to JSON - with open(segments_json_path, 'w') as file: - json.dump(segments, file, indent=2) - logging.debug(f"Transcription segments saved to {segments_json_path}") + json_str = json.dumps(segments, indent=2) + with open(segments_json_path, 'w', encoding='utf-8') as f: + f.write(json_str) + + logging.debug(f"Valid transcription segments saved to {segments_json_path}") return audio_file_path, segments except Exception as e: - 
logging.error(f"Error in re-generating transcription: {str(e)}") + logging.error(f"Error in re_generate_transcription: {str(e)}") return None, None @@ -1191,17 +1256,40 @@ def save_transcription_and_summary(transcription_text, summary_text, download_pa try: video_title = sanitize_filename(info_dict.get('title', 'Untitled')) + # Handle different transcription_text formats + if isinstance(transcription_text, dict): + if 'transcription' in transcription_text: + # Handle the case where it's a dict with 'transcription' key + text_to_save = '\n'.join(segment['Text'] for segment in transcription_text['transcription']) + else: + # Handle other dictionary formats + text_to_save = str(transcription_text) + elif isinstance(transcription_text, list): + # Handle list of segments + text_to_save = '\n'.join(segment['Text'] for segment in transcription_text) + else: + # Handle string input + text_to_save = str(transcription_text) + + # Validate the extracted text + if not text_to_save or not text_to_save.strip(): + logging.error("Transcription text is empty or contains only whitespace") + return None, None + # Save transcription transcription_file_path = os.path.join(download_path, f"{video_title}_transcription.txt") with open(transcription_file_path, 'w', encoding='utf-8') as f: - f.write(transcription_text) + f.write(text_to_save) # Save summary if available summary_file_path = None if summary_text: - summary_file_path = os.path.join(download_path, f"{video_title}_summary.txt") - with open(summary_file_path, 'w', encoding='utf-8') as f: - f.write(summary_text) + if isinstance(summary_text, str) and summary_text.strip(): + summary_file_path = os.path.join(download_path, f"{video_title}_summary.txt") + with open(summary_file_path, 'w', encoding='utf-8') as f: + f.write(summary_text) + else: + logging.warning("Summary text is not a string or contains only whitespace") return transcription_file_path, summary_file_path except Exception as e: diff --git 
a/App_Function_Libraries/Third_Party/Anki.py b/App_Function_Libraries/Third_Party/Anki.py new file mode 100644 index 0000000000000000000000000000000000000000..49b28131929600fdc716139898f97fb88f3a98b4 --- /dev/null +++ b/App_Function_Libraries/Third_Party/Anki.py @@ -0,0 +1,640 @@ +# Anki.py +# Description: Functions for Anki card generation +# +# Imports +import json +import zipfile +import sqlite3 +import tempfile +import os +import shutil +import base64 +import time +from datetime import datetime +from pathlib import Path +from typing import Dict, Tuple, Optional, Any, List +import re +from html.parser import HTMLParser +# +# External Imports +#from outlines import models, prompts +# Local Imports +# +############################################################################################################ +# +# Functions: + +class HTMLImageExtractor(HTMLParser): + """Extract and validate image tags from HTML content.""" + + def __init__(self): + super().__init__() + self.images = [] + + def handle_starttag(self, tag, attrs): + if tag == 'img': + attrs_dict = dict(attrs) + if 'src' in attrs_dict: + self.images.append(attrs_dict['src']) + + +def sanitize_html(content: str) -> str: + """Sanitize HTML content while preserving valid image tags and basic formatting.""" + if not content: + return "" + + # Allow basic formatting and image tags + allowed_tags = {'img', 'b', 'i', 'u', 'div', 'br', 'p', 'span'} + allowed_attrs = {'src', 'alt', 'class', 'style'} + + # Remove potentially harmful attributes + content = re.sub(r'(on\w+)="[^"]*"', '', content) + content = re.sub(r'javascript:', '', content) + + # Parse and rebuild HTML + parser = HTMLParser() + parser.feed(content) + return content + + +def extract_media_from_apkg(zip_path: Any, temp_dir: str) -> Dict[str, str]: + """Extract and process media files from APKG.""" + media_files = {} + try: + # Handle file path whether it's a string or file object + if hasattr(zip_path, 'name'): + # It's a file object from 
Gradio + file_name = zip_path.name + else: + # It's a string path + file_name = str(zip_path) + + with zipfile.ZipFile(file_name, 'r') as zip_ref: + if 'media' in zip_ref.namelist(): + media_json = json.loads(zip_ref.read('media').decode('utf-8')) + + for file_id, filename in media_json.items(): + if str(file_id) in zip_ref.namelist(): + file_data = zip_ref.read(str(file_id)) + file_path = os.path.join(temp_dir, filename) + + # Save file temporarily + with open(file_path, 'wb') as f: + f.write(file_data) + + # Process supported image types + if any(filename.lower().endswith(ext) for ext in ['.jpg', '.jpeg', '.png', '.gif']): + try: + with open(file_path, 'rb') as f: + file_content = f.read() + file_ext = os.path.splitext(filename)[1].lower() + media_type = f"image/{file_ext[1:]}" + if file_ext == '.jpg': + media_type = "image/jpeg" + media_files[ + filename] = f"data:{media_type};base64,{base64.b64encode(file_content).decode('utf-8')}" + except Exception as e: + print(f"Error processing image {filename}: {str(e)}") + + # Clean up temporary file + os.remove(file_path) + + except Exception as e: + print(f"Error processing media: {str(e)}") + return media_files + + +def validate_card_content(card: Dict[str, Any], seen_ids: set) -> list: + """Validate individual card content and structure.""" + issues = [] + + # Check required fields + if 'id' not in card: + issues.append("Missing ID") + elif card['id'] in seen_ids: + issues.append("Duplicate ID") + else: + seen_ids.add(card['id']) + + if 'type' not in card or card['type'] not in ['basic', 'cloze', 'reverse']: + issues.append("Invalid card type") + + if 'front' not in card or not card['front'].strip(): + issues.append("Missing front content") + + if 'back' not in card or not card['back'].strip(): + issues.append("Missing back content") + + if 'tags' not in card or not card['tags']: + issues.append("Missing tags") + + # Content-specific validation + if card.get('type') == 'cloze': + if '{{c1::' not in card['front']: + 
issues.append("Invalid cloze format") + + # Image validation + for field in ['front', 'back']: + if ' Tuple[Optional[Dict], Optional[Dict], str]: + """Process APKG file with support for different Anki database versions.""" + if not file_path: + return None, None, "No file provided" + # Handle file path whether it's a string or file object + if hasattr(file_path, 'name'): + # It's a file object from Gradio + file_name = file_path.name + else: + # It's a string path + file_name = str(file_path) + + temp_dir = None + db_conn = None + cursor = None + cards_data = {"cards": []} + deck_info = None + + try: + # Create temporary directory + temp_dir = tempfile.mkdtemp() + + # Extract media files first + media_files = extract_media_from_apkg(file_name, temp_dir) + + # Extract APKG contents + with zipfile.ZipFile(file_name, 'r') as zip_ref: + zip_ref.extractall(temp_dir) + zip_ref.extractall(temp_dir) + + db_path = os.path.join(temp_dir, 'collection.anki2') + + # Process database with explicit connection management + db_conn = sqlite3.connect(db_path) + cursor = db_conn.cursor() + + try: + # Get collection info + cursor.execute("SELECT decks, models FROM col") + decks_json, models_json = cursor.fetchone() + deck_info = { + "decks": json.loads(decks_json), + "models": json.loads(models_json) + } + + # Check if we're dealing with an older or newer Anki version + try: + cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='notetypes'") + has_notetypes = cursor.fetchone() is not None + + if has_notetypes: + # New Anki version (2.1.28+) + cursor.execute(""" + SELECT + n.id, n.flds, n.tags, c.type, n.mid, + m.name, n.sfld, m.flds, m.tmpls + FROM notes n + JOIN cards c ON c.nid = n.id + JOIN notetypes m ON m.id = n.mid + """) + else: + # Older Anki version + cursor.execute(""" + SELECT + n.id, n.flds, n.tags, c.type, n.mid, + m.name, n.sfld, m.flds, m.tmpls + FROM notes n + JOIN cards c ON c.nid = n.id + JOIN col AS m ON m.id = 1 AND json_extract(m.models, 
'$.' || n.mid) IS NOT NULL + """) + + rows = cursor.fetchall() + + except sqlite3.Error as e: + # Fallback query for very old Anki versions + cursor.execute(""" + SELECT + n.id, n.flds, n.tags, c.type, n.mid, + '', n.sfld, '[]', '[]' + FROM notes n + JOIN cards c ON c.nid = n.id + """) + rows = cursor.fetchall() + + finally: + cursor.close() + db_conn.close() + + # Process the fetched data + for row in rows: + note_id, fields, tags, card_type, model_id = row[0:5] + model_name = row[5] if row[5] else "Unknown Model" + fields_list = fields.split('\x1f') + + try: + fields_config = json.loads(row[7]) if row[7] else [] + templates = json.loads(row[8]) if row[8] else [] + except json.JSONDecodeError: + fields_config = [] + templates = [] + + # Process fields with media + processed_fields = [] + for field in fields_list: + field_html = field + for filename, base64_data in media_files.items(): + field_html = field_html.replace( + f' 1 else "", + "tags": tags.strip().split(" ") if tags.strip() else ["imported"], + "note": f"Imported from deck: {model_name}", + "has_media": any(' Tuple[bool, str]: + """Validate flashcard content with enhanced image support.""" + try: + data = json.loads(content) + validation_results = [] + is_valid = True + + if not isinstance(data, dict) or 'cards' not in data: + return False, "Invalid JSON format. Must contain 'cards' array." + + seen_ids = set() + for idx, card in enumerate(data['cards']): + card_issues = validate_card_content(card, seen_ids) + + if card_issues: + is_valid = False + validation_results.append(f"Card {card['id']}: {', '.join(card_issues)}") + + return is_valid, "\n".join(validation_results) if validation_results else "All cards are valid!" 
+ + except json.JSONDecodeError: + return False, "Invalid JSON format" + except Exception as e: + return False, f"Validation error: {str(e)}" + + +def enhanced_file_upload(file: Any, input_type: str) -> Tuple[Optional[str], Optional[Dict], str, List[str]]: + """Enhanced file upload handler with better error handling.""" + if not file: + return None, None, "❌ No file uploaded", [] + + try: + if input_type == "APKG": + cards_data, deck_info, message = process_apkg_file(file) + if cards_data: + content = json.dumps(cards_data, indent=2) + choices = update_card_choices(content) + # Validate the converted content + validation_msg = handle_validation(content, "APKG") + return content, deck_info, validation_msg, choices + return None, None, f"❌ {message}", [] + else: + # Original JSON file handling + content = file.read().decode('utf-8') + json.loads(content) # Validate JSON + return content, None, "✅ JSON file loaded successfully!", update_card_choices(content) + except Exception as e: + return None, None, f"❌ Error processing file: {str(e)}", [] + +def handle_file_upload(file: Any, input_type: str) -> Tuple[Optional[str], Optional[Dict], str, List[str]]: + """Handle file upload with proper validation message formatting and card choices update.""" + if not file: + return None, None, "❌ No file uploaded", [] + + if input_type == "APKG": + cards_data, deck_info, message = process_apkg_file(file) + if cards_data: + content = json.dumps(cards_data, indent=2) + return ( + content, + deck_info, + f"✅ {message}", + update_card_choices(content) + ) + return None, None, f"❌ {message}", [] + else: # JSON + try: + content = file.read().decode('utf-8') + json.loads(content) # Validate JSON + return ( + content, + None, + "✅ JSON file loaded successfully!", + update_card_choices(content) + ) + except Exception as e: + return None, None, f"❌ Error loading JSON file: {str(e)}", [] + +def update_card_choices(content: str) -> List[str]: + """Update card choices for the dropdown.""" + 
try: + data = json.loads(content) + return [f"{card['id']} - {card['front'][:50]}..." for card in data['cards']] + except: + return [] + + +def update_card_content( + current_content: str, + card_id: str, + card_type: str, + front: str, + back: str, + tags: str, + notes: str +) -> Tuple[str, str]: + """Update card content and return updated JSON and status message.""" + try: + data = json.loads(current_content) + + for card in data['cards']: + if card['id'] == card_id: + # Sanitize input content + card['type'] = card_type + card['front'] = sanitize_html(front) + card['back'] = sanitize_html(back) + card['tags'] = [tag.strip() for tag in tags.split(',')] + card['note'] = notes + + # Update media status + card['has_media'] = ' tuple: + """ + Load a card for editing and generate previews. + + Args: + card_selection (str): Selected card ID and preview text + current_content (str): Current JSON content + + Returns: + tuple: (card_type, front_content, back_content, tags, notes, front_preview, back_preview) + """ + if not card_selection or not current_content: + return "basic", "", "", "", "", "", "" + + try: + data = json.loads(current_content) + selected_id = card_selection.split(" - ")[0] + + for card in data['cards']: + if card['id'] == selected_id: + # Return all required fields with preview content + return ( + card['type'], + card['front'], + card['back'], + ", ".join(card['tags']), + card.get('note', ''), + sanitize_html(card['front']), + sanitize_html(card['back']) + ) + + return "basic", "", "", "", "", "", "" + + except Exception as e: + print(f"Error loading card: {str(e)}") + return "basic", "", "", "", "", "", "" + + +def export_cards(content: str, format_type: str) -> Tuple[str, Optional[Tuple[str, str, str]]]: + """Export cards in the specified format.""" + try: + is_valid, validation_message = validate_flashcards(content) + if not is_valid: + return "Please fix validation issues before exporting.", None + + data = json.loads(content) + + if format_type == 
"Anki CSV": + output = "Front,Back,Tags,Type,Note\n" + for card in data['cards']: + output += f'"{card["front"]}","{card["back"]}","{" ".join(card["tags"])}","{card["type"]}","{card.get("note", "")}"\n' + return "Cards exported successfully!", ("anki_cards.csv", output, "text/csv") + + elif format_type == "JSON": + return "Cards exported successfully!", ("anki_cards.json", content, "application/json") + + else: # Plain Text + output = "" + for card in data['cards']: + # Replace image tags with placeholders + front = re.sub(r']+>', '[IMG]', card['front']) + back = re.sub(r']+>', '[IMG]', card['back']) + output += f"Q: {front}\nA: {back}\nTags: {', '.join(card['tags'])}\n\n" + return "Cards exported successfully!", ("anki_cards.txt", output, "text/plain") + + except Exception as e: + return f"Export error: {str(e)}", None + + +def generate_card_choices(content: str) -> list: + """Generate choices for card selector dropdown.""" + try: + data = json.loads(content) + return [f"{card['id']} - {card['front'][:50]}..." for card in data['cards']] + except: + return [] + +def format_validation_result(content: str) -> str: + """Format validation results for display in Markdown component.""" + try: + is_valid, message = validate_flashcards(content) + return f"✅ {message}" if is_valid else f"❌ {message}" + except Exception as e: + return f"❌ Error during validation: {str(e)}" + + +def validate_for_ui(content: str) -> str: + """Validate flashcards and return a formatted string for UI display.""" + if not content or not content.strip(): + return "❌ No content to validate. Please enter some flashcard data." 
+ + try: + # First try to parse the JSON + try: + data = json.loads(content) + except json.JSONDecodeError as je: + # Provide more specific JSON error feedback + line_col = f" (line {je.lineno}, column {je.colno})" if hasattr(je, 'lineno') else "" + return f"❌ Invalid JSON format: {str(je)}{line_col}" + + # Check basic structure + if not isinstance(data, dict): + return "❌ Invalid format: Root element must be a JSON object" + + if "cards" not in data: + return '❌ Invalid format: Missing "cards" array in root object' + + if not isinstance(data["cards"], list): + return '❌ Invalid format: "cards" must be an array' + + if not data["cards"]: + return "❌ No cards found in the data" + + # If we get here, perform the full validation + is_valid, message = validate_flashcards(content) + if is_valid: + return f"✅ {message}" + else: + return f"❌ {message}" + + except Exception as e: + return f"❌ Validation error: {str(e)}" + + +def update_card_with_validation( + current_content: str, + card_selection: str, + card_type: str, + front: str, + back: str, + tags: str, + notes: str +) -> Tuple[str, str, List[str]]: + """Update card and return properly formatted validation message and updated choices.""" + try: + # Unpack the tuple returned by update_card_content + updated_content, message = update_card_content( + current_content, + card_selection.split(" - ")[0], + card_type, + front, + back, + tags, + notes + ) + + if "successfully" in message: + return ( + updated_content, + f"✅ {message}", + update_card_choices(updated_content) + ) + else: + return ( + current_content, + f"❌ {message}", + update_card_choices(current_content) + ) + except Exception as e: + return ( + current_content, + f"❌ Error updating card: {str(e)}", + update_card_choices(current_content) + ) + + +def handle_validation(content: str, input_format: str) -> str: + """Handle validation for both JSON and APKG formats.""" + if not content or not content.strip(): + return "❌ No content to validate" + + try: + data = 
json.loads(content) + + if not isinstance(data, dict): + return "❌ Invalid format: Root element must be a JSON object" + + if "cards" not in data: + return '❌ Invalid format: Missing "cards" array in root object' + + if not isinstance(data["cards"], list): + return '❌ Invalid format: "cards" must be an array' + + if not data["cards"]: + return "❌ No cards found in the data" + + card_count = len(data["cards"]) + if input_format == "APKG": + return f"✅ Successfully imported and validated {card_count} cards from APKG file" + else: + # For JSON input, perform additional validation + is_valid, message = validate_flashcards(content) + return f"✅ {message}" if is_valid else f"❌ {message}" + + except json.JSONDecodeError as je: + line_col = f" (line {je.lineno}, column {je.colno})" if hasattr(je, 'lineno') else "" + return f"❌ Invalid JSON format: {str(je)}{line_col}" + except Exception as e: + return f"❌ Validation error: {str(e)}" + +# +# End of Anki.py +############################################################################################################ diff --git a/App_Function_Libraries/Third_Party/Semantic_Scholar.py b/App_Function_Libraries/Third_Party/Semantic_Scholar.py new file mode 100644 index 0000000000000000000000000000000000000000..d20610d1e459c331db567c350a4a3408c4e5c6ec --- /dev/null +++ b/App_Function_Libraries/Third_Party/Semantic_Scholar.py @@ -0,0 +1,162 @@ +# Semantic_Scholar.py +# Description: This file contains the functions to interact with the Semantic Scholar API +# +# Imports +from typing import List, Dict, Any + +import requests +# +#################################################################################################### +# +# Functions + +# Constants +FIELDS_OF_STUDY = [ + "Computer Science", "Medicine", "Chemistry", "Biology", "Materials Science", + "Physics", "Geology", "Psychology", "Art", "History", "Geography", + "Sociology", "Business", "Political Science", "Economics", "Philosophy", + "Mathematics", "Engineering", 
"Environmental Science", + "Agricultural and Food Sciences", "Education", "Law", "Linguistics" +] + +PUBLICATION_TYPES = [ + "Review", "JournalArticle", "CaseReport", "ClinicalTrial", "Conference", + "Dataset", "Editorial", "LettersAndComments", "MetaAnalysis", "News", + "Study", "Book", "BookSection" +] + + +def search_papers( + query: str, + page: int, + fields_of_study: List[str], + publication_types: List[str], + year_range: str, + venue: str, + min_citations: int, + open_access_only: bool, + limit: int = 10 +) -> Dict[str, Any]: + """Search for papers using the Semantic Scholar API with all available filters""" + if not query.strip(): + return {"total": 0, "offset": 0, "next": 0, "data": []} + + try: + url = "https://api.semanticscholar.org/graph/v1/paper/search" + params = { + "query": query, + "offset": page * limit, + "limit": limit, + "fields": "title,abstract,year,citationCount,authors,venue,openAccessPdf,url,publicationTypes,publicationDate" + } + + # Add optional filters + if fields_of_study: + params["fieldsOfStudy"] = ",".join(fields_of_study) + if publication_types: + params["publicationTypes"] = ",".join(publication_types) + if venue: + params["venue"] = venue + if min_citations: + params["minCitationCount"] = str(min_citations) + if open_access_only: + params["openAccessPdf"] = "" + if year_range: + try: + if "-" in year_range: + start_year, end_year = year_range.split("-") + params["year"] = f"{start_year.strip()}-{end_year.strip()}" + else: + params["year"] = year_range.strip() + except ValueError: + pass + + response = requests.get(url, params=params) + response.raise_for_status() + return response.json() + except requests.exceptions.RequestException as e: + return {"error": f"API Error: {str(e)}", "total": 0, "offset": 0, "data": []} + + +def get_paper_details(paper_id): + """Get detailed information about a specific paper""" + try: + url = f"https://api.semanticscholar.org/graph/v1/paper/{paper_id}" + params = { + "fields": 
"title,abstract,year,citationCount,authors,venue,openAccessPdf,url,references,citations" + } + response = requests.get(url, params=params) + response.raise_for_status() + return response.json() + except requests.exceptions.RequestException as e: + return {"error": f"API Error: {str(e)}"} + + +def format_paper_info(paper: Dict[str, Any]) -> str: + """Format paper information for display""" + authors = ", ".join([author["name"] for author in paper.get("authors", [])]) + year = f"Year: {paper.get('year', 'N/A')}" + venue = f"Venue: {paper.get('venue', 'N/A')}" + citations = f"Citations: {paper.get('citationCount', 0)}" + pub_types = f"Types: {', '.join(paper.get('publicationTypes', ['N/A']))}" + + pdf_link = "" + if paper.get("openAccessPdf"): + pdf_link = f"\nPDF: {paper['openAccessPdf']['url']}" + + s2_link = f"\nSemantic Scholar: {paper.get('url', '')}" + + formatted = f"""# {paper.get('title', 'No Title')} + +Authors: {authors} +{year} | {venue} | {citations} +{pub_types} + +Abstract: +{paper.get('abstract', 'No abstract available')} + +Links:{pdf_link}{s2_link} +""" + return formatted + + +def search_and_display( + query: str, + page: int, + fields_of_study: List[str], + publication_types: List[str], + year_range: str, + venue: str, + min_citations: int, + open_access_only: bool +) -> tuple[str, int, int, str]: + """Search for papers and return formatted results with pagination info""" + result = search_papers( + query, page, fields_of_study, publication_types, + year_range, venue, min_citations, open_access_only + ) + + if "error" in result: + return result["error"], 0, 0, "0" + + if not result["data"]: + return "No results found.", 0, 0, "0" + + papers = result["data"] + total_results = int(result.get("total", "0")) + max_pages = (total_results + 9) // 10 # Ceiling division + + results = [] + for paper in papers: + results.append(format_paper_info(paper)) + + formatted_results = "\n\n---\n\n".join(results) + + # Add pagination information + pagination_info = 
f"\n\n---\n\nShowing results {result['offset'] + 1}-{result['offset'] + len(papers)} of {total_results}" + + return formatted_results + pagination_info, page, max_pages - 1, str(total_results) + +# +# End of Semantic_Scholar.py +#################################################################################################### diff --git a/App_Function_Libraries/Utils/Utils.py b/App_Function_Libraries/Utils/Utils.py index 1e0df0ccfc04408c3782edab7cabab3243cdcd5e..569253a5d6f517f18291cbd46cefd4826b909c04 100644 --- a/App_Function_Libraries/Utils/Utils.py +++ b/App_Function_Libraries/Utils/Utils.py @@ -95,8 +95,6 @@ def cleanup_downloads(): ####################################################################################################################### # Config loading # - - def load_comprehensive_config(): # Get the directory of the current script (Utils.py) current_dir = os.path.dirname(os.path.abspath(__file__)) @@ -126,25 +124,33 @@ def load_comprehensive_config(): def get_project_root(): - # Get the directory of the current file (Utils.py) + """Get the absolute path to the project root directory.""" current_dir = os.path.dirname(os.path.abspath(__file__)) - # Go up two levels to reach the project root - # Assuming the structure is: project_root/App_Function_Libraries/Utils/Utils.py project_root = os.path.dirname(os.path.dirname(current_dir)) + logging.debug(f"Project root: {project_root}") return project_root + def get_database_dir(): - """Get the database directory (/tldw/Databases/).""" + """Get the absolute path to the database directory.""" db_dir = os.path.join(get_project_root(), 'Databases') + os.makedirs(db_dir, exist_ok=True) logging.debug(f"Database directory: {db_dir}") return db_dir -def get_database_path(db_name: Union[str, os.PathLike[AnyStr]]) -> str: - """Get the full path for a database file.""" - path = os.path.join(get_database_dir(), str(db_name)) - logging.debug(f"Database path for {db_name}: {path}") + +def get_database_path(db_name: 
str) -> str: + """ + Get the full absolute path for a database file. + Ensures the path is always within the Databases directory. + """ + # Remove any directory traversal attempts + safe_db_name = os.path.basename(db_name) + path = os.path.join(get_database_dir(), safe_db_name) + logging.debug(f"Database path for {safe_db_name}: {path}") return path + def get_project_relative_path(relative_path: Union[str, os.PathLike[AnyStr]]) -> str: """Convert a relative path to a path relative to the project root.""" path = os.path.join(get_project_root(), str(relative_path)) @@ -254,6 +260,8 @@ def load_and_log_configs(): logging.debug(f"Loaded Tabby API IP: {tabby_api_IP}") logging.debug(f"Loaded VLLM API URL: {vllm_api_url}") + # Retrieve default API choices from the configuration file + default_api = config.get('API', 'default_api', fallback='openai') # Retrieve output paths from the configuration file output_path = config.get('Paths', 'output_path', fallback='results') @@ -278,6 +286,10 @@ def load_and_log_configs(): # Prompts - FIXME prompt_path = config.get('Prompts', 'prompt_path', fallback='Databases/prompts.db') + # Auto-Save Values + save_character_chats = config.get('Auto-Save', 'save_character_chats', fallback='False') + save_rag_chats = config.get('Auto-Save', 'save_rag_chats', fallback='False') + return { 'api_keys': { 'anthropic': anthropic_api_key, @@ -340,13 +352,47 @@ def load_and_log_configs(): 'embedding_api_key': embedding_api_key, 'chunk_size': chunk_size, 'overlap': overlap - } + }, + 'auto-save': { + 'save_character_chats': save_character_chats, + 'save_rag_chats': save_rag_chats, + }, + 'default_api': default_api } except Exception as e: logging.error(f"Error loading config: {str(e)}") return None +global_api_endpoints = ["anthropic", "cohere", "groq", "openai", "huggingface", "openrouter", "deepseek", "mistral", "custom_openai_api", "llama", "ooba", "kobold", "tabby", "vllm", "ollama", "aphrodite"] + +# Setup Default API Endpoint +loaded_config_data = 
load_and_log_configs() +default_api_endpoint = loaded_config_data['default_api'] + +def format_api_name(api): + name_mapping = { + "openai": "OpenAI", + "anthropic": "Anthropic", + "cohere": "Cohere", + "groq": "Groq", + "huggingface": "HuggingFace", + "openrouter": "OpenRouter", + "deepseek": "DeepSeek", + "mistral": "Mistral", + "custom_openai_api": "Custom-OpenAI-API", + "llama": "Llama.cpp", + "ooba": "Ooba", + "kobold": "Kobold", + "tabby": "Tabbyapi", + "vllm": "VLLM", + "ollama": "Ollama", + "aphrodite": "Aphrodite" + } + return name_mapping.get(api, api.title()) +print(f"Default API Endpoint: {default_api_endpoint}") + + # # End of Config loading diff --git a/App_Function_Libraries/Web_Scraping/Article_Extractor_Lib.py b/App_Function_Libraries/Web_Scraping/Article_Extractor_Lib.py index 7b3efbe85be337a7223c75c0c4c9b7988766b38e..4ff3b1f44e00fe3c08f26f9ca3e513dc67c9a318 100644 --- a/App_Function_Libraries/Web_Scraping/Article_Extractor_Lib.py +++ b/App_Function_Libraries/Web_Scraping/Article_Extractor_Lib.py @@ -13,29 +13,38 @@ #################### # # Import necessary libraries +import hashlib +from datetime import datetime import json import logging -# 3rd-Party Imports -import asyncio import os import tempfile -from datetime import datetime -from typing import List, Dict, Union +from typing import Any, Dict, List, Union, Optional, Tuple +# +# 3rd-Party Imports +import asyncio from urllib.parse import urljoin, urlparse from xml.dom import minidom -from playwright.async_api import async_playwright +import xml.etree.ElementTree as ET +# +# External Libraries from bs4 import BeautifulSoup +import pandas as pd +from playwright.async_api import async_playwright import requests import trafilatura -import xml.etree.ElementTree as ET - - -# Import Local # +# Import Local +from App_Function_Libraries.DB.DB_Manager import ingest_article_to_db +from App_Function_Libraries.Summarization.Summarization_General_Lib import summarize 
####################################################################################################################### # Function Definitions # +################################################################# +# +# Scraping-related functions: + def get_page_title(url: str) -> str: try: response = requests.get(url) @@ -48,21 +57,24 @@ def get_page_title(url: str) -> str: return "Untitled" -async def scrape_article(url): +async def scrape_article(url: str, custom_cookies: Optional[List[Dict[str, Any]]] = None) -> Dict[str, Any]: async def fetch_html(url: str) -> str: async with async_playwright() as p: browser = await p.chromium.launch(headless=True) context = await browser.new_context( - user_agent="Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/58.0.3029.110 Safari/537.3") + user_agent="Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/58.0.3029.110 Safari/537.3" + ) + if custom_cookies: + await context.add_cookies(custom_cookies) page = await context.new_page() await page.goto(url) - await page.wait_for_load_state("networkidle") # Wait for the network to be idle + await page.wait_for_load_state("networkidle") content = await page.content() await browser.close() return content - # FIXME - Add option for extracting comments/tables/images def extract_article_data(html: str, url: str) -> dict: + # FIXME - Add option for extracting comments/tables/images downloaded = trafilatura.extract(html, include_comments=False, include_tables=False, include_images=False) metadata = trafilatura.extract_metadata(html) @@ -76,7 +88,16 @@ async def scrape_article(url): } if downloaded: - result['content'] = downloaded + # Add metadata to content + result['content'] = ContentMetadataHandler.format_content_with_metadata( + url=url, + content=downloaded, + pipeline="Trafilatura", + additional_metadata={ + "extracted_date": metadata.date if metadata and metadata.date else 'N/A', + "author": metadata.author 
if metadata and metadata.author else 'N/A' + } + ) result['extraction_successful'] = True if metadata: @@ -108,96 +129,178 @@ async def scrape_article(url): return article_data -def collect_internal_links(base_url: str) -> set: - visited = set() - to_visit = {base_url} +async def scrape_and_summarize_multiple( + urls: str, + custom_prompt_arg: Optional[str], + api_name: str, + api_key: Optional[str], + keywords: str, + custom_article_titles: Optional[str], + system_message: Optional[str] = None, + summarize_checkbox: bool = False, + custom_cookies: Optional[List[Dict[str, Any]]] = None, + temperature: float = 0.7 +) -> List[Dict[str, Any]]: + urls_list = [url.strip() for url in urls.split('\n') if url.strip()] + custom_titles = custom_article_titles.split('\n') if custom_article_titles else [] + + results = [] + errors = [] + + # Loop over each URL to scrape and optionally summarize + for i, url in enumerate(urls_list): + custom_title = custom_titles[i] if i < len(custom_titles) else None + try: + # Scrape the article + article = await scrape_article(url, custom_cookies=custom_cookies) + if article and article['extraction_successful']: + if custom_title: + article['title'] = custom_title + + # If summarization is requested + if summarize_checkbox: + content = article.get('content', '') + if content: + # Prepare prompts + system_message_final = system_message or "Act as a professional summarizer and summarize this article." + article_custom_prompt = custom_prompt_arg or "Act as a professional summarizer and summarize this article." + + # Summarize the content using the summarize function + summary = summarize( + input_data=content, + custom_prompt_arg=article_custom_prompt, + api_name=api_name, + api_key=api_key, + temp=temperature, + system_message=system_message_final + ) + article['summary'] = summary + logging.info(f"Summary generated for URL {url}") + else: + article['summary'] = "No content available to summarize." 
+ logging.warning(f"No content to summarize for URL {url}") + else: + article['summary'] = None - while to_visit: - current_url = to_visit.pop() - if current_url in visited: - continue + results.append(article) + else: + error_message = f"Extraction unsuccessful for URL {url}" + errors.append(error_message) + logging.error(error_message) + except Exception as e: + error_message = f"Error processing URL {i + 1} ({url}): {str(e)}" + errors.append(error_message) + logging.error(error_message, exc_info=True) - try: - response = requests.get(current_url) - response.raise_for_status() - soup = BeautifulSoup(response.text, 'html.parser') + if errors: + logging.error("\n".join(errors)) - # Collect internal links - for link in soup.find_all('a', href=True): - full_url = urljoin(base_url, link['href']) - # Only process links within the same domain - if urlparse(full_url).netloc == urlparse(base_url).netloc: - if full_url not in visited: - to_visit.add(full_url) + if not results: + logging.error("No articles were successfully scraped and summarized/analyzed.") + return [] - visited.add(current_url) - except requests.RequestException as e: - logging.error(f"Error visiting {current_url}: {e}") - continue + return results - return visited +def scrape_and_no_summarize_then_ingest(url, keywords, custom_article_title): + try: + # Step 1: Scrape the article + article_data = asyncio.run(scrape_article(url)) + print(f"Scraped Article Data: {article_data}") # Debugging statement + if not article_data: + return "Failed to scrape the article." 
-def generate_temp_sitemap_from_links(links: set) -> str: + # Use the custom title if provided, otherwise use the scraped title + title = custom_article_title.strip() if custom_article_title else article_data.get('title', 'Untitled') + author = article_data.get('author', 'Unknown') + content = article_data.get('content', '') + ingestion_date = datetime.now().strftime('%Y-%m-%d') + + print(f"Title: {title}, Author: {author}, Content Length: {len(content)}") # Debugging statement + + # Step 2: Ingest the article into the database + ingestion_result = ingest_article_to_db(url, title, author, content, keywords, ingestion_date, None, None) + + # When displaying content, we might want to strip metadata + display_content = ContentMetadataHandler.strip_metadata(content) + return f"Title: {title}\nAuthor: {author}\nIngestion Result: {ingestion_result}\n\nArticle Contents: {display_content}" + except Exception as e: + logging.error(f"Error processing URL {url}: {str(e)}") + return f"Failed to process URL {url}: {str(e)}" + + +def scrape_from_filtered_sitemap(sitemap_file: str, filter_function) -> list: """ - Generate a temporary sitemap file from collected links and return its path. + Scrape articles from a sitemap file, applying an additional filter function. 
- :param links: A set of URLs to include in the sitemap - :return: Path to the temporary sitemap file + :param sitemap_file: Path to the sitemap file + :param filter_function: A function that takes a URL and returns True if it should be scraped + :return: List of scraped articles """ - # Create the root element - urlset = ET.Element("urlset") - urlset.set("xmlns", "http://www.sitemaps.org/schemas/sitemap/0.9") + try: + tree = ET.parse(sitemap_file) + root = tree.getroot() - # Add each link to the sitemap - for link in links: - url = ET.SubElement(urlset, "url") - loc = ET.SubElement(url, "loc") - loc.text = link - lastmod = ET.SubElement(url, "lastmod") - lastmod.text = datetime.now().strftime("%Y-%m-%d") - changefreq = ET.SubElement(url, "changefreq") - changefreq.text = "daily" - priority = ET.SubElement(url, "priority") - priority.text = "0.5" + articles = [] + for url in root.findall('.//{http://www.sitemaps.org/schemas/sitemap/0.9}loc'): + if filter_function(url.text): + article_data = scrape_article(url.text) + if article_data: + articles.append(article_data) - # Create the tree and get it as a string - xml_string = ET.tostring(urlset, 'utf-8') + return articles + except ET.ParseError as e: + logging.error(f"Error parsing sitemap: {e}") + return [] - # Pretty print the XML - pretty_xml = minidom.parseString(xml_string).toprettyxml(indent=" ") - # Create a temporary file - with tempfile.NamedTemporaryFile(mode="w", suffix=".xml", delete=False) as temp_file: - temp_file.write(pretty_xml) - temp_file_path = temp_file.name +def is_content_page(url: str) -> bool: + """ + Determine if a URL is likely to be a content page. + This is a basic implementation and may need to be adjusted based on the specific website structure. 
- logging.info(f"Temporary sitemap created at: {temp_file_path}") - return temp_file_path + :param url: The URL to check + :return: True if the URL is likely a content page, False otherwise + """ + #Add more specific checks here based on the website's structure + # Exclude common non-content pages + exclude_patterns = [ + '/tag/', '/category/', '/author/', '/search/', '/page/', + 'wp-content', 'wp-includes', 'wp-json', 'wp-admin', + 'login', 'register', 'cart', 'checkout', 'account', + '.jpg', '.png', '.gif', '.pdf', '.zip' + ] + return not any(pattern in url.lower() for pattern in exclude_patterns) +def scrape_and_convert_with_filter(source: str, output_file: str, filter_function=is_content_page, level: int = None): + """ + Scrape articles from a sitemap or by URL level, apply filtering, and convert to a single markdown file. -def generate_sitemap_for_url(url: str) -> List[Dict[str, str]]: + :param source: URL of the sitemap, base URL for level-based scraping, or path to a local sitemap file + :param output_file: Path to save the output markdown file + :param filter_function: Function to filter URLs (default is is_content_page) + :param level: URL level for scraping (None if using sitemap) """ - Generate a sitemap for the given URL using the create_filtered_sitemap function. 
+ if level is not None: + # Scraping by URL level + articles = scrape_by_url_level(source, level) + articles = [article for article in articles if filter_function(article['url'])] + elif source.startswith('http'): + # Scraping from online sitemap + articles = scrape_from_sitemap(source) + articles = [article for article in articles if filter_function(article['url'])] + else: + # Scraping from local sitemap file + articles = scrape_from_filtered_sitemap(source, filter_function) - Args: - url (str): The base URL to generate the sitemap for + articles = [article for article in articles if filter_function(article['url'])] + markdown_content = convert_to_markdown(articles) - Returns: - List[Dict[str, str]]: A list of dictionaries, each containing 'url' and 'title' keys - """ - with tempfile.NamedTemporaryFile(mode="w+", suffix=".xml", delete=False) as temp_file: - create_filtered_sitemap(url, temp_file.name, is_content_page) - temp_file.seek(0) - tree = ET.parse(temp_file.name) - root = tree.getroot() + with open(output_file, 'w', encoding='utf-8') as f: + f.write(markdown_content) - sitemap = [] - for url_elem in root.findall(".//{http://www.sitemaps.org/schemas/sitemap/0.9}url"): - loc = url_elem.find("{http://www.sitemaps.org/schemas/sitemap/0.9}loc").text - sitemap.append({"url": loc, "title": loc.split("/")[-1] or url}) # Use the last part of the URL as a title + logging.info(f"Scraped and filtered content saved to {output_file}") - return sitemap async def scrape_entire_site(base_url: str) -> List[Dict]: """ @@ -267,37 +370,103 @@ def scrape_from_sitemap(sitemap_url: str) -> list: logging.error(f"Error fetching sitemap: {e}") return [] +# +# End of Scraping Functions +####################################################### +# +# Sitemap/Crawling-related Functions -def convert_to_markdown(articles: list) -> str: - """Convert a list of article data into a single markdown document.""" - markdown = "" - for article in articles: - markdown += f"# 
{article['title']}\n\n" - markdown += f"Author: {article['author']}\n" - markdown += f"Date: {article['date']}\n\n" - markdown += f"{article['content']}\n\n" - markdown += "---\n\n" # Separator between articles - return markdown +def collect_internal_links(base_url: str) -> set: + visited = set() + to_visit = {base_url} -def is_content_page(url: str) -> bool: + while to_visit: + current_url = to_visit.pop() + if current_url in visited: + continue + + try: + response = requests.get(current_url) + response.raise_for_status() + soup = BeautifulSoup(response.text, 'html.parser') + + # Collect internal links + for link in soup.find_all('a', href=True): + full_url = urljoin(base_url, link['href']) + # Only process links within the same domain + if urlparse(full_url).netloc == urlparse(base_url).netloc: + if full_url not in visited: + to_visit.add(full_url) + + visited.add(current_url) + except requests.RequestException as e: + logging.error(f"Error visiting {current_url}: {e}") + continue + + return visited + + +def generate_temp_sitemap_from_links(links: set) -> str: """ - Determine if a URL is likely to be a content page. - This is a basic implementation and may need to be adjusted based on the specific website structure. + Generate a temporary sitemap file from collected links and return its path. 
- :param url: The URL to check - :return: True if the URL is likely a content page, False otherwise + :param links: A set of URLs to include in the sitemap + :return: Path to the temporary sitemap file """ - #Add more specific checks here based on the website's structure - # Exclude common non-content pages - exclude_patterns = [ - '/tag/', '/category/', '/author/', '/search/', '/page/', - 'wp-content', 'wp-includes', 'wp-json', 'wp-admin', - 'login', 'register', 'cart', 'checkout', 'account', - '.jpg', '.png', '.gif', '.pdf', '.zip' - ] - return not any(pattern in url.lower() for pattern in exclude_patterns) + # Create the root element + urlset = ET.Element("urlset") + urlset.set("xmlns", "http://www.sitemaps.org/schemas/sitemap/0.9") + + # Add each link to the sitemap + for link in links: + url = ET.SubElement(urlset, "url") + loc = ET.SubElement(url, "loc") + loc.text = link + lastmod = ET.SubElement(url, "lastmod") + lastmod.text = datetime.now().strftime("%Y-%m-%d") + changefreq = ET.SubElement(url, "changefreq") + changefreq.text = "daily" + priority = ET.SubElement(url, "priority") + priority.text = "0.5" + + # Create the tree and get it as a string + xml_string = ET.tostring(urlset, 'utf-8') + + # Pretty print the XML + pretty_xml = minidom.parseString(xml_string).toprettyxml(indent=" ") + # Create a temporary file + with tempfile.NamedTemporaryFile(mode="w", suffix=".xml", delete=False) as temp_file: + temp_file.write(pretty_xml) + temp_file_path = temp_file.name + + logging.info(f"Temporary sitemap created at: {temp_file_path}") + return temp_file_path + + +def generate_sitemap_for_url(url: str) -> List[Dict[str, str]]: + """ + Generate a sitemap for the given URL using the create_filtered_sitemap function. 
+ + Args: + url (str): The base URL to generate the sitemap for + + Returns: + List[Dict[str, str]]: A list of dictionaries, each containing 'url' and 'title' keys + """ + with tempfile.NamedTemporaryFile(mode="w+", suffix=".xml", delete=False) as temp_file: + create_filtered_sitemap(url, temp_file.name, is_content_page) + temp_file.seek(0) + tree = ET.parse(temp_file.name) + root = tree.getroot() + + sitemap = [] + for url_elem in root.findall(".//{http://www.sitemaps.org/schemas/sitemap/0.9}url"): + loc = url_elem.find("{http://www.sitemaps.org/schemas/sitemap/0.9}loc").text + sitemap.append({"url": loc, "title": loc.split("/")[-1] or url}) # Use the last part of the URL as a title + + return sitemap def create_filtered_sitemap(base_url: str, output_file: str, filter_function): """ @@ -323,61 +492,44 @@ def create_filtered_sitemap(base_url: str, output_file: str, filter_function): print(f"Filtered sitemap saved to {output_file}") -def scrape_from_filtered_sitemap(sitemap_file: str, filter_function) -> list: - """ - Scrape articles from a sitemap file, applying an additional filter function. 
- - :param sitemap_file: Path to the sitemap file - :param filter_function: A function that takes a URL and returns True if it should be scraped - :return: List of scraped articles - """ - try: - tree = ET.parse(sitemap_file) - root = tree.getroot() - - articles = [] - for url in root.findall('.//{http://www.sitemaps.org/schemas/sitemap/0.9}loc'): - if filter_function(url.text): - article_data = scrape_article(url.text) - if article_data: - articles.append(article_data) - - return articles - except ET.ParseError as e: - logging.error(f"Error parsing sitemap: {e}") - return [] +# +# End of Crawling Functions +################################################################# +# +# Utility Functions +def convert_to_markdown(articles: list) -> str: + """Convert a list of article data into a single markdown document.""" + markdown = "" + for article in articles: + markdown += f"# {article['title']}\n\n" + markdown += f"Author: {article['author']}\n" + markdown += f"Date: {article['date']}\n\n" + markdown += f"{article['content']}\n\n" + markdown += "---\n\n" # Separator between articles + return markdown -def scrape_and_convert_with_filter(source: str, output_file: str, filter_function=is_content_page, level: int = None): - """ - Scrape articles from a sitemap or by URL level, apply filtering, and convert to a single markdown file. 
+def compute_content_hash(content: str) -> str: + return hashlib.sha256(content.encode('utf-8')).hexdigest() - :param source: URL of the sitemap, base URL for level-based scraping, or path to a local sitemap file - :param output_file: Path to save the output markdown file - :param filter_function: Function to filter URLs (default is is_content_page) - :param level: URL level for scraping (None if using sitemap) - """ - if level is not None: - # Scraping by URL level - articles = scrape_by_url_level(source, level) - articles = [article for article in articles if filter_function(article['url'])] - elif source.startswith('http'): - # Scraping from online sitemap - articles = scrape_from_sitemap(source) - articles = [article for article in articles if filter_function(article['url'])] +def load_hashes(filename: str) -> Dict[str, str]: + if os.path.exists(filename): + with open(filename, 'r') as f: + return json.load(f) else: - # Scraping from local sitemap file - articles = scrape_from_filtered_sitemap(source, filter_function) + return {} - articles = [article for article in articles if filter_function(article['url'])] - markdown_content = convert_to_markdown(articles) +def save_hashes(hashes: Dict[str, str], filename: str): + with open(filename, 'w') as f: + json.dump(hashes, f) - with open(output_file, 'w', encoding='utf-8') as f: - f.write(markdown_content) - - logging.info(f"Scraped and filtered content saved to {output_file}") +def has_page_changed(url: str, new_hash: str, stored_hashes: Dict[str, str]) -> bool: + old_hash = stored_hashes.get(url) + return old_hash != new_hash +# +# ################################################### # # Bookmark Parsing Functions @@ -497,6 +649,72 @@ def collect_bookmarks(file_path: str) -> Dict[str, Union[str, List[str]]]: logging.error(f"Error loading bookmarks: {e}") return {} + +def parse_csv_urls(file_path: str) -> Dict[str, Union[str, List[str]]]: + """ + Parse URLs from a CSV file. 
The CSV should have at minimum a 'url' column, + and optionally a 'title' or 'name' column. + + :param file_path: Path to the CSV file + :return: Dictionary with titles/names as keys and URLs as values + """ + try: + # Read CSV file + df = pd.read_csv(file_path) + + # Check if required columns exist + if 'url' not in df.columns: + raise ValueError("CSV must contain a 'url' column") + + # Initialize result dictionary + urls_dict = {} + + # Determine which column to use as key + key_column = next((col for col in ['title', 'name'] if col in df.columns), None) + + for idx in range(len(df)): + url = df.iloc[idx]['url'].strip() + + # Use title/name if available, otherwise use URL as key + if key_column: + key = df.iloc[idx][key_column].strip() + else: + key = f"Article {idx + 1}" + + # Handle duplicate keys + if key in urls_dict: + if isinstance(urls_dict[key], list): + urls_dict[key].append(url) + else: + urls_dict[key] = [urls_dict[key], url] + else: + urls_dict[key] = url + + return urls_dict + + except pd.errors.EmptyDataError: + logging.error("The CSV file is empty") + return {} + except Exception as e: + logging.error(f"Error parsing CSV file: {str(e)}") + return {} + + +def collect_urls_from_file(file_path: str) -> Dict[str, Union[str, List[str]]]: + """ + Unified function to collect URLs from either bookmarks or CSV files. 
class ContentMetadataHandler:
    """Handles the addition and parsing of metadata for scraped content.

    Formatted content has the layout::

        [METADATA]
        {...JSON object...}
        [/METADATA]

        <content>

    A metadata header is only recognised when the closing marker appears
    *after* the opening marker, so ordinary text that merely mentions one
    of the markers is never mistaken for a header.
    """

    METADATA_START = "[METADATA]"
    METADATA_END = "[/METADATA]"

    @staticmethod
    def _locate_metadata(content: str) -> Optional[Tuple[int, int]]:
        """Find the metadata header inside *content*.

        Args:
            content: Text that may contain a metadata header

        Returns:
            ``(json_start, end_marker_index)`` where ``json_start`` is the
            index just past the opening marker and ``end_marker_index`` is
            the index of the first closing marker after it, or ``None``
            when no well-ordered marker pair exists.
        """
        start_marker = ContentMetadataHandler.METADATA_START
        end_marker = ContentMetadataHandler.METADATA_END

        marker_pos = content.find(start_marker)
        if marker_pos == -1:
            return None

        json_start = marker_pos + len(start_marker)
        # Search from json_start so a stray closing marker that precedes
        # the opening marker is ignored.
        end_pos = content.find(end_marker, json_start)
        if end_pos == -1:
            return None
        return json_start, end_pos

    @staticmethod
    def format_content_with_metadata(
            url: str,
            content: str,
            pipeline: str = "Trafilatura",
            additional_metadata: Optional[Dict[str, Any]] = None
    ) -> str:
        """
        Format content with metadata header.

        Args:
            url: The source URL
            content: The scraped content
            pipeline: The scraping pipeline used
            additional_metadata: Optional dictionary of additional metadata to
                include; its keys override the built-in ones on collision

        Returns:
            Formatted content with metadata header. The stored
            ``content_hash`` is the SHA-256 of *content* exactly as passed
            (no whitespace normalisation).
        """
        metadata = {
            "url": url,
            "ingestion_date": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
            "content_hash": hashlib.sha256(content.encode('utf-8')).hexdigest(),
            "scraping_pipeline": pipeline
        }

        # Add any additional metadata
        if additional_metadata:
            metadata.update(additional_metadata)

        return (
            f"{ContentMetadataHandler.METADATA_START}\n"
            f"{json.dumps(metadata, indent=2)}\n"
            f"{ContentMetadataHandler.METADATA_END}\n"
            f"\n"
            f"{content}"
        )

    @staticmethod
    def extract_metadata(content: str) -> Tuple[Dict[str, Any], str]:
        """
        Extract metadata and content separately.

        Args:
            content: The full content including metadata

        Returns:
            Tuple of (metadata dict, clean content). When no header is
            found, or the header is not a valid JSON object, the result is
            ``({}, content)`` with *content* returned untouched.
        """
        span = ContentMetadataHandler._locate_metadata(content)
        if span is None:
            return {}, content

        json_start, end_pos = span
        try:
            metadata = json.loads(content[json_start:end_pos])
        except json.JSONDecodeError:
            return {}, content

        # Valid JSON that is not an object (e.g. a list) is not a header.
        if not isinstance(metadata, dict):
            return {}, content

        clean_content = content[end_pos + len(ContentMetadataHandler.METADATA_END):].strip()
        return metadata, clean_content

    @staticmethod
    def has_metadata(content: str) -> bool:
        """
        Check if content contains metadata.

        Args:
            content: The content to check

        Returns:
            bool: True if an opening marker followed by a closing marker is
            present
        """
        return ContentMetadataHandler._locate_metadata(content) is not None

    @staticmethod
    def strip_metadata(content: str) -> str:
        """
        Remove metadata from content if present.

        Args:
            content: The content to strip metadata from

        Returns:
            The whitespace-trimmed text after the closing marker, or
            *content* unchanged when there is no metadata header
        """
        span = ContentMetadataHandler._locate_metadata(content)
        if span is None:
            return content
        _, end_pos = span
        return content[end_pos + len(ContentMetadataHandler.METADATA_END):].strip()

    @staticmethod
    def get_content_hash(content: str) -> str:
        """
        Get hash of content without metadata.

        Surrounding whitespace is always ignored, so the same text hashes
        identically whether or not it carries a metadata header.

        Args:
            content: The content to hash

        Returns:
            SHA-256 hash of the clean content
        """
        clean_content = ContentMetadataHandler.strip_metadata(content).strip()
        return hashlib.sha256(clean_content.encode('utf-8')).hexdigest()

    @staticmethod
    def content_changed(old_content: str, new_content: str) -> bool:
        """
        Check if content has changed by comparing hashes.

        Args:
            old_content: Previous version of content
            new_content: New version of content

        Returns:
            bool: True if content has changed
        """
        old_hash = ContentMetadataHandler.get_content_hash(old_content)
        new_hash = ContentMetadataHandler.get_content_hash(new_content)
        return old_hash != new_hash