# acres / app.py — ak3ra — "Refactor chat_function to unpack response tuple
# and return only the response" — commit f139cfc
# (NOTE: these header lines are residue scraped from the hosting page and are
# not Python; kept as a comment so the module parses.)
# app.py
import csv
import datetime
# from datetime import datetime
import io
import json
import logging
import os
from typing import Tuple, List, Any
import gradio as gr
import openai
from dotenv import load_dotenv
from slugify import slugify
from config import STUDY_FILES, OPENAI_API_KEY
from rag.rag_pipeline import RAGPipeline
from utils.helpers import (
append_to_study_files,
add_study_files_to_chromadb,
chromadb_client,
)
from utils.prompts import highlight_prompt, evidence_based_prompt
from utils.zotero_manager import ZoteroManager
from interface import create_chat_interface
from utils.pdf_processor import PDFProcessor
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Load environment variables from .env (e.g. the OpenAI key consumed
# via config.OPENAI_API_KEY) and wire the key into the openai client.
load_dotenv()
openai.api_key = OPENAI_API_KEY

# Initialize ChromaDB with study files
# Seeds the "study_files_collection" collection from study_files.json at
# import time so the dropdown/query helpers below have data to read.
add_study_files_to_chromadb("study_files.json", "study_files_collection")

# Cache for RAG pipelines
# Maps study name -> RAGPipeline so each study's index is built only once
# per process; populated lazily by get_rag_pipeline().
rag_cache = {}
def get_rag_pipeline(study_name: str) -> RAGPipeline:
    """Return a cached RAGPipeline for *study_name*, building it on first use.

    The study's source file path is looked up in the ChromaDB
    "study_files_collection" collection; raises ValueError when the study
    is unknown or has no recorded file path.
    """
    cached = rag_cache.get(study_name)
    if cached is not None:
        return cached

    collection = chromadb_client.get_or_create_collection("study_files_collection")
    # Documents are keyed by study name, so a single-ID get() suffices.
    record = collection.get(ids=[study_name])
    if not record or len(record["metadatas"]) == 0:
        raise ValueError(f"Invalid study name: {study_name}")

    file_path = record["metadatas"][0].get("file_path")
    if not file_path:
        raise ValueError(f"File path not found for study name: {study_name}")

    pipeline = RAGPipeline(file_path)
    rag_cache[study_name] = pipeline
    return pipeline
def get_study_info(study_name: str) -> str:
    """Look up *study_name* in ChromaDB and report its document count.

    Reads the study's JSON export from disk and returns a Markdown line;
    raises ValueError when the study or its file path is missing.
    """
    collection = chromadb_client.get_or_create_collection("study_files_collection")
    # IDs must be passed as a list, even for a single lookup.
    record = collection.get(ids=[study_name])
    logging.info(f"Result: ======> {record}")

    if not record or len(record["metadatas"]) == 0:
        raise ValueError(f"Invalid study name: {study_name}")

    file_path = record["metadatas"][0].get("file_path")
    logging.info(f"study_file: =======> {file_path}")
    if not file_path:
        raise ValueError(f"File path not found for study name: {study_name}")

    with open(file_path, "r") as handle:
        documents = json.load(handle)
    return f"### Number of documents: {len(documents)}"
def markdown_table_to_csv(markdown_text: str) -> str:
    """Convert a markdown table embedded in *markdown_text* to CSV.

    Every line starting with ``|`` is treated as a table row; the markdown
    alignment row (e.g. ``| --- | :---: |``) is skipped.  Returns "" when
    no table rows are found.

    Fixes two defects of the previous version:
    * interior empty cells were dropped (``if cell.strip()``), silently
      shifting columns left in the CSV output;
    * any data row merely *containing* "---" was discarded.
    """
    lines = [line.strip() for line in markdown_text.split("\n") if line.strip()]
    table_lines = [line for line in lines if line.startswith("|")]
    if not table_lines:
        return ""

    csv_rows = []
    for line in table_lines:
        # Strip only the outer pipes, then split: empty cells survive.
        cells = [cell.strip() for cell in line.strip("|").split("|")]
        # Separator row: every cell is dashes, optionally with colons.
        if all(cell and set(cell) <= {"-", ":"} for cell in cells):
            continue
        csv_rows.append(cells)

    output = io.StringIO()
    csv.writer(output).writerows(csv_rows)
    return output.getvalue()
def cleanup_temp_files():
    """Remove stale CSV exports (study_export_*.csv) from the working dir.

    Files older than five minutes are deleted; all failures are logged as
    warnings and never propagated to the caller.

    Fix: the age threshold was 20 seconds while the comment claimed
    "5 minutes in seconds" — almost certainly a debug leftover; it now
    matches the stated intent.
    """
    max_age_seconds = 300  # 5 minutes — exports only need to outlive the download
    try:
        now = datetime.datetime.now()
        for file in os.listdir():
            if file.startswith("study_export_") and file.endswith(".csv"):
                modified = datetime.datetime.fromtimestamp(os.path.getmtime(file))
                if (now - modified).total_seconds() > max_age_seconds:
                    try:
                        os.remove(file)
                    except Exception as e:
                        logging.warning(f"Failed to remove temp file {file}: {e}")
    except Exception as e:
        logging.warning(f"Error during cleanup: {e}")
def chat_function(message: str, study_name: str, prompt_type: str) -> str:
    """Answer *message* about *study_name* through the RAG pipeline.

    ``prompt_type`` selects an optional prompt template ("Highlight" or
    "Evidence-based"); any other value falls back to the pipeline default.
    Blank messages return a fixed validation string instead of querying.
    """
    if not message.strip():
        return "Please enter a valid query."

    rag = get_rag_pipeline(study_name)
    logging.info(f"rag: ==> {rag}")

    if prompt_type == "Highlight":
        template = highlight_prompt
    elif prompt_type == "Evidence-based":
        template = evidence_based_prompt
    else:
        template = None

    # rag.query returns (response_text, source_info); only the text is needed.
    answer, _ = rag.query(message, prompt_template=template)
    return answer
def process_zotero_library_items(
    zotero_library_id: str, zotero_api_access_key: str
) -> str:
    """Import every non-empty Zotero collection as a study.

    For each collection not already registered in STUDY_FILES, its items
    are exported to ``data/<slug>_zotero_items.json``, appended to
    ``study_files.json``, and (after the loop) re-indexed into ChromaDB.

    Returns a human-readable status message; all errors are captured and
    reported in the return value rather than raised.

    Fixes: removed a duplicate, unused API fetch (``get_collection_items``
    whose result was never read), removed redundant self-assignments, and
    corrected the error message grammar ("Error process" -> "processing").
    """
    if not zotero_library_id or not zotero_api_access_key:
        return "Please enter your zotero library Id and API Access Key"

    zotero_library_type = "user"  # or "group"
    message = ""
    try:
        zotero_manager = ZoteroManager(
            zotero_library_id, zotero_library_type, zotero_api_access_key
        )

        zotero_collections = zotero_manager.get_collections()
        zotero_collection_lists = zotero_manager.list_zotero_collections(
            zotero_collections
        )
        filtered_zotero_collection_lists = (
            zotero_manager.filter_and_return_collections_with_items(
                zotero_collection_lists
            )
        )

        study_files_data = {}  # Dictionary to collect items for ChromaDB

        for collection in filtered_zotero_collection_lists:
            collection_name = collection.get("name")
            # Only import collections we have not seen before this session.
            if collection_name not in STUDY_FILES:
                collection_key = collection.get("key")
                zotero_collection_items = (
                    zotero_manager.get_collection_zotero_items_by_key(collection_key)
                )
                # Export zotero collection items to json
                zotero_items_json = zotero_manager.zotero_items_to_json(
                    zotero_collection_items
                )
                export_file = f"{slugify(collection_name)}_zotero_items.json"
                zotero_manager.write_zotero_items_to_json_file(
                    zotero_items_json, f"data/{export_file}"
                )
                append_to_study_files(
                    "study_files.json", collection_name, f"data/{export_file}"
                )

                # Collect for ChromaDB
                study_files_data[collection_name] = f"data/{export_file}"

                # Update in-memory STUDY_FILES for reference in current session
                STUDY_FILES.update({collection_name: f"data/{export_file}"})
                logging.info(f"STUDY_FILES: {STUDY_FILES}")

        # After the loop, re-index everything recorded in study_files.json.
        add_study_files_to_chromadb("study_files.json", "study_files_collection")

        message = "Successfully processed items in your zotero library"
    except Exception as e:
        message = f"Error processing your zotero library: {str(e)}"
    return message
def process_multi_input(text, study_name, prompt_type):
    """Turn comma-separated variable names into a tabular-extraction query.

    Returns a two-element list: the chat response and a Gradio update
    that makes the CSV download button visible.
    """
    # Normalise each requested variable name to upper case.
    requested = ", ".join(part.strip().upper() for part in text.split(","))
    query = (
        f"Extract and present in a tabular format the following variables "
        f"for each {study_name} study: {requested}"
    )
    logging.info(f"User message: ==> {query}")
    answer = chat_function(query, study_name, prompt_type)
    return [answer, gr.update(visible=True)]
def download_as_csv(markdown_content):
    """Write a markdown table out as a timestamped CSV file for download.

    Returns the created file's path, or None when there is no table
    content to export.
    """
    if not markdown_content:
        return None

    csv_content = markdown_table_to_csv(markdown_content)
    if not csv_content:
        return None

    # Timestamped filename so concurrent exports never collide;
    # cleanup_temp_files() reaps these files later.
    stamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    export_path = f"study_export_{stamp}.csv"
    with open(export_path, "w", newline="", encoding="utf-8") as out:
        out.write(csv_content)
    return export_path
# PDF Support
def process_pdf_uploads(files: List[gr.File], collection_name: str) -> str:
    """Ingest uploaded PDF files into a new study collection.

    The PDFs are processed by PDFProcessor, registered under the id
    ``pdf_<slugified-name>`` in study_files.json, and indexed into
    ChromaDB.  Returns a status message describing success or failure.
    """
    if not files or not collection_name:
        return "Please upload PDF files and provide a collection name"

    try:
        processor = PDFProcessor()

        # Resolve each upload to a path on disk, writing the file ourselves
        # only when Gradio has not already saved it to a temp location.
        file_paths = []
        for file in files:
            if hasattr(file, "name"):  # already materialised at a temp path
                path = file.name
            else:
                path = os.path.join(processor.upload_dir, file.orig_name)
                file.save(path)
            file_paths.append(path)

        output_path = processor.process_pdfs(file_paths, collection_name)

        # Register the new collection and refresh the ChromaDB index.
        collection_id = f"pdf_{slugify(collection_name)}"
        append_to_study_files("study_files.json", collection_id, output_path)
        add_study_files_to_chromadb("study_files.json", "study_files_collection")

        # Remove only the temp copies this function created above.
        for path in file_paths:
            if path.startswith(processor.upload_dir):
                try:
                    os.remove(path)
                except Exception as e:
                    logger.warning(f"Failed to remove temporary file {path}: {e}")

        return f"Successfully processed PDFs into collection: {collection_id}"
    except Exception as e:
        logger.error(f"Error in process_pdf_uploads: {str(e)}")
        return f"Error processing PDF files: {str(e)}"
def chat_response(
    message: str,
    history: List[Tuple[str, str]],
    study_name: str,
    pdf_processor: "PDFProcessor",
) -> Tuple[List[Tuple[str, str]], Any]:
    """Generate a chat response, append it to *history*, and render a preview.

    Returns ``(history, preview_image)`` where the preview is an image of
    the first source page reported by the pipeline, or None when no source
    info is available or the message is blank.

    Fix: the blank-message branch previously returned three values while
    the normal path returned two (and the annotation claimed a 3-tuple);
    both paths now consistently return a 2-tuple.
    """
    if not message.strip():
        return history, None

    rag = get_rag_pipeline(study_name)
    response, source_info = rag.query(message)
    history.append((message, response))

    # Render a preview of the first cited page when the pipeline reports
    # where the answer came from.
    preview_image = None
    if (
        source_info
        and source_info.get("source_file")
        and source_info.get("page_numbers")
    ):
        try:
            page_num = source_info["page_numbers"][0]
            preview_image = pdf_processor.render_page(
                source_info["source_file"], int(page_num)
            )
        except Exception as e:
            logger.error(f"Error generating PDF preview: {str(e)}")

    return history, preview_image
def create_gr_interface() -> gr.Blocks:
    """Create and configure the Gradio interface for the RAG platform.

    Builds two tabs:
      * "Study Analysis" — Zotero import, study selection, and tabular
        variable extraction with CSV export.
      * "PDF Chat" — PDF upload plus a chat box with a source-page preview.

    Fix: ``handle_pdf_upload`` used to be registered on ``upload_btn.click``
    twice, so a single click processed the uploaded PDFs two times; the
    duplicate registration has been removed.
    """
    with gr.Blocks() as demo:
        gr.Markdown("# ACRES RAG Platform")

        with gr.Tabs() as tabs:
            # Tab 1: Original Study Analysis Interface
            with gr.Tab("Study Analysis"):
                with gr.Row():
                    with gr.Column(scale=1):
                        gr.Markdown("### Zotero Credentials")
                        zotero_library_id = gr.Textbox(
                            label="Zotero Library ID",
                            type="password",
                            placeholder="Enter Your Zotero Library ID here...",
                        )
                        zotero_api_access_key = gr.Textbox(
                            label="Zotero API Access Key",
                            type="password",
                            placeholder="Enter Your Zotero API Access Key...",
                        )
                        process_zotero_btn = gr.Button("Process your Zotero Library")
                        zotero_output = gr.Markdown(label="Zotero")

                        gr.Markdown("### Study Information")
                        # Populate the dropdown from everything currently
                        # indexed in ChromaDB (up to 1000 studies).
                        collection = chromadb_client.get_or_create_collection(
                            "study_files_collection"
                        )
                        all_documents = collection.query(
                            query_texts=[""], n_results=1000
                        )
                        # query() nests ids one list per query text.
                        study_choices = [
                            doc_id
                            for doc_id in all_documents.get("ids")[0]
                            if all_documents
                        ]
                        study_dropdown = gr.Dropdown(
                            choices=study_choices,
                            label="Select Study",
                            value=(study_choices[0] if study_choices else None),
                        )
                        study_info = gr.Markdown(label="Study Details")
                        prompt_type = gr.Radio(
                            ["Default", "Highlight", "Evidence-based"],
                            label="Prompt Type",
                            value="Default",
                        )

                    with gr.Column(scale=3):
                        gr.Markdown("### Study Variables")
                        with gr.Row():
                            study_variables = gr.Textbox(
                                show_label=False,
                                placeholder="Type your variables separated by commas e.g (Study ID, Study Title, Authors etc)",
                                scale=4,
                                lines=1,
                                autofocus=True,
                            )
                            submit_btn = gr.Button("Submit", scale=1)
                        answer_output = gr.Markdown(label="Answer")
                        download_btn = gr.DownloadButton(
                            "Download as CSV",
                            variant="primary",
                            size="sm",
                            scale=1,
                            visible=False,
                        )

            # Tab 2: PDF Chat Interface
            with gr.Tab("PDF Chat"):
                pdf_processor = PDFProcessor()
                with gr.Row():
                    # Left column: Chat and Input
                    with gr.Column(scale=7):
                        chat_history = gr.Chatbot(
                            value=[], height=600, show_label=False
                        )
                        with gr.Row():
                            query_input = gr.Textbox(
                                show_label=False,
                                placeholder="Ask a question about your PDFs...",
                                scale=8,
                            )
                            chat_submit_btn = gr.Button(
                                "Send", scale=2, variant="primary"
                            )

                    # Right column: PDF Preview and Upload
                    with gr.Column(scale=3):
                        pdf_preview = gr.Image(label="Source Page", height=600)
                        with gr.Row():
                            pdf_files = gr.File(
                                file_count="multiple",
                                file_types=[".pdf"],
                                label="Upload PDFs",
                            )
                        with gr.Row():
                            collection_name = gr.Textbox(
                                label="Collection Name",
                                placeholder="Name this PDF collection...",
                            )
                        with gr.Row():
                            upload_btn = gr.Button("Process PDFs", variant="primary")
                        pdf_status = gr.Markdown()
                        # Holds the collection id produced by the most
                        # recent PDF upload, consumed by the chat handler.
                        current_collection = gr.State(value=None)

        # Event handlers for Study Analysis tab
        process_zotero_btn.click(
            process_zotero_library_items,
            inputs=[zotero_library_id, zotero_api_access_key],
            outputs=[zotero_output],
        )
        study_dropdown.change(
            get_study_info, inputs=[study_dropdown], outputs=[study_info]
        )
        submit_btn.click(
            process_multi_input,
            inputs=[study_variables, study_dropdown, prompt_type],
            outputs=[answer_output, download_btn],
        )
        # The DownloadButton doubles as the file source: download_as_csv
        # returns the CSV path, then stale exports are reaped.
        download_btn.click(
            fn=download_as_csv, inputs=[answer_output], outputs=[download_btn]
        ).then(fn=cleanup_temp_files, inputs=None, outputs=None)

        # Event handlers for PDF Chat tab
        def handle_pdf_upload(files, name):
            """Validate inputs, then process the PDFs into a collection."""
            if not name:
                return "Please provide a collection name", None
            if not files:
                return "Please select PDF files", None

            try:
                result = process_pdf_uploads(files, name)
                collection_id = f"pdf_{slugify(name)}"
                return result, collection_id
            except Exception as e:
                logger.error(f"Error in handle_pdf_upload: {str(e)}")
                return f"Error: {str(e)}", None

        def add_message(history, message):
            """Add user message to chat history."""
            if not message.strip():
                raise gr.Error("Please enter a message")
            history = history + [(message, None)]
            return history, "", None

        def generate_chat_response(history, collection_id, pdf_processor):
            """Generate response for the last message in history."""
            if not collection_id:
                raise gr.Error("Please upload PDFs first")
            if len(history) == 0:
                return history, None

            last_message = history[-1][0]
            try:
                # Get response and source info
                rag = get_rag_pipeline(collection_id)
                response, source_info = rag.query(last_message)

                # Generate preview if source information is available
                preview_image = None
                if (
                    source_info
                    and source_info.get("source_file")
                    and source_info.get("page_number") is not None
                ):
                    try:
                        page_num = source_info["page_number"]
                        logger.info(f"Attempting to render page {page_num}")
                        preview_image = pdf_processor.render_page(
                            source_info["source_file"], page_num
                        )
                        if preview_image:
                            logger.info(
                                f"Successfully generated preview for page {page_num}"
                            )
                        else:
                            logger.warning(
                                f"Failed to generate preview for page {page_num}"
                            )
                    except Exception as e:
                        logger.error(f"Error generating PDF preview: {str(e)}")
                        preview_image = None

                # Update history with response
                history[-1] = (last_message, response)
                return history, preview_image
            except Exception as e:
                logger.error(f"Error in generate_chat_response: {str(e)}")
                history[-1] = (last_message, f"Error: {str(e)}")
                return history, None

        # Single registration (was duplicated): process PDFs on click.
        upload_btn.click(
            handle_pdf_upload,
            inputs=[pdf_files, collection_name],
            outputs=[pdf_status, current_collection],
        )

        # Chat: commit the user message first, then generate the answer
        # and the source-page preview once the message is in the history.
        chat_submit_btn.click(
            add_message,
            inputs=[chat_history, query_input],
            outputs=[chat_history, query_input, pdf_preview],
        ).success(
            lambda h, c: generate_chat_response(h, c, pdf_processor),
            inputs=[chat_history, current_collection],
            outputs=[chat_history, pdf_preview],
        )

    return demo
# Build the UI at import time so hosting platforms (e.g. HF Spaces) that
# import `demo` directly can serve it without running this file as a script.
demo = create_gr_interface()

if __name__ == "__main__":
    # demo = create_gr_interface()
    # share=True exposes a public Gradio link; debug=True enables verbose logs.
    demo.launch(share=True, debug=True)