ak3ra commited on
Commit
daee42b
β€’
1 Parent(s): 8f8aff5

modularization

Browse files
app.py CHANGED
@@ -1,36 +1,54 @@
1
  import gradio as gr
2
- import os
3
- from rag_pipeline import RAGPipeline
4
- import openai
5
- openai.api_key = os.environ.get('OPENAI_API_KEY')
6
-
7
- # Initialize the RAG pipeline
8
- rag = RAGPipeline("metadata_map.json", "pdfs")
9
-
10
- def process_query(question, response_format):
11
- response = rag.query(question)
12
-
13
- if response_format == "Markdown":
14
- return response["markdown"]
15
  else:
16
- return response["raw"]
17
-
18
- # Define the Gradio interface
19
- iface = gr.Interface(
20
- fn=process_query,
21
- inputs=[
22
- gr.Textbox(lines=2, placeholder="Enter your question here...", label="Question"),
23
- gr.Radio(["Markdown", "Raw Text"], label="Response Format", value="Markdown")
24
- ],
25
- outputs=gr.Markdown(label="Response"),
26
- title="Vaccine Coverage and Hesitancy Research QA",
27
- description="Ask questions about vaccine coverage and hesitancy. The system will provide answers based on the available research papers.",
28
- examples=[
29
- ["What are the main factors contributing to vaccine hesitancy?", "Markdown"],
30
- ["What are the current vaccine coverage rates in African countries?", "Raw Text"],
31
- ],
32
- allow_flagging="never"
33
- )
34
-
35
- # Launch the app
36
- iface.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
+ from database.vaccine_coverage_db import VaccineCoverageDB
3
+ from rag.rag_pipeline import RAGPipeline
4
+ from utils.helpers import process_response
5
+ from config import DB_PATH, METADATA_FILE, PDF_DIR
6
+
7
+ # Initialize database and RAG pipeline
8
+ db = VaccineCoverageDB(DB_PATH)
9
+ rag = RAGPipeline(METADATA_FILE, PDF_DIR, use_semantic_splitter=True)
10
+
11
+
12
+ def query_rag(question, prompt_type):
13
+ if prompt_type == "Highlight":
14
+ response = rag.query(question, prompt_type="highlight")
15
  else:
16
+ response = rag.query(question, prompt_type="evidence_based")
17
+
18
+ processed = process_response(response)
19
+ return processed["markdown"]
20
+
21
+
22
+ def save_pdf(item_key):
23
+ attachments = db.get_attachments_for_item(item_key)
24
+ if attachments:
25
+ attachment_key = attachments[0]["key"]
26
+ output_path = f"{attachment_key}.pdf"
27
+ if db.save_pdf_to_file(attachment_key, output_path):
28
+ return f"PDF saved successfully to {output_path}"
29
+ return "Failed to save PDF or no attachments found"
30
+
31
+
32
+ # Gradio interface
33
+ with gr.Blocks() as demo:
34
+ gr.Markdown("# Vaccine Coverage Study RAG System")
35
+
36
+ with gr.Tab("Query"):
37
+ question_input = gr.Textbox(label="Enter your question")
38
+ prompt_type = gr.Radio(["Highlight", "Evidence-based"], label="Prompt Type")
39
+ query_button = gr.Button("Submit Query")
40
+ output = gr.Markdown(label="Response")
41
+
42
+ query_button.click(
43
+ query_rag, inputs=[question_input, prompt_type], outputs=output
44
+ )
45
+
46
+ with gr.Tab("Save PDF"):
47
+ item_key_input = gr.Textbox(label="Enter item key")
48
+ save_button = gr.Button("Save PDF")
49
+ save_output = gr.Textbox(label="Save Result")
50
+
51
+ save_button.click(save_pdf, inputs=item_key_input, outputs=save_output)
52
+
53
+ if __name__ == "__main__":
54
+ demo.launch()
database/vaccine_coverage_db.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sqlite3
2
+ from typing import List, Dict, Any
3
+
4
+
5
+ class VaccineCoverageDB:
6
+ def __init__(self, db_path: str):
7
+ self.conn = sqlite3.connect(db_path)
8
+ self.conn.row_factory = sqlite3.Row
9
+
10
+ def get_all_items(self) -> List[Dict[str, Any]]:
11
+ cursor = self.conn.execute("SELECT * FROM items")
12
+ return [dict(row) for row in cursor.fetchall()]
13
+
14
+ def get_item_by_key(self, key: str) -> Dict[str, Any]:
15
+ cursor = self.conn.execute("SELECT * FROM items WHERE key = ?", (key,))
16
+ return dict(cursor.fetchone())
17
+
18
+ def get_attachments_for_item(self, item_key: str) -> List[Dict[str, Any]]:
19
+ cursor = self.conn.execute(
20
+ "SELECT * FROM attachments WHERE parent_key = ?", (item_key,)
21
+ )
22
+ return [dict(row) for row in cursor.fetchall()]
23
+
24
+ def get_pdf_content(self, attachment_key: str) -> bytes:
25
+ cursor = self.conn.execute(
26
+ "SELECT content FROM attachments WHERE key = ?", (attachment_key,)
27
+ )
28
+ result = cursor.fetchone()
29
+ return result["content"] if result else None
30
+
31
+ def save_pdf_to_file(self, attachment_key: str, output_path: str) -> bool:
32
+ pdf_content = self.get_pdf_content(attachment_key)
33
+ if pdf_content:
34
+ try:
35
+ with open(output_path, "wb") as f:
36
+ f.write(pdf_content)
37
+ return True
38
+ except Exception as e:
39
+ print(f"Error saving PDF: {str(e)}")
40
+ return False
41
+ else:
42
+ print(f"No PDF content found for attachment key: {attachment_key}")
43
+ return False
44
+
45
+ def close(self):
46
+ self.conn.close()
rag_pipeline.py β†’ rag/rag_pipeline.py RENAMED
File without changes