File size: 2,457 Bytes
9a8819b
0198065
9a8819b
 
 
 
fa330f5
9a8819b
0198065
1badf2f
9a8819b
1badf2f
 
 
 
9a8819b
 
2414c16
9a8819b
2414c16
9a8819b
 
 
 
 
2414c16
9a8819b
 
 
0198065
9a8819b
 
 
 
 
 
 
 
 
0198065
9a8819b
 
0198065
9a8819b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0198065
 
 
 
1badf2f
0198065
 
 
6e0b698
0198065
 
 
 
 
 
9a8819b
0198065
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
# Install necessary libraries
#!pip install PyPDF2 transformers torch accelerate streamlit

from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import PyPDF2
import streamlit as st

# Function to extract text from PDF
def extract_text_from_pdf(uploaded_file):
    """Extract and concatenate the text of every page in a PDF.

    Parameters
    ----------
    uploaded_file : file-like object
        A binary file-like object accepted by ``PyPDF2.PdfReader``
        (e.g. the object returned by ``st.file_uploader``).

    Returns
    -------
    str
        The concatenated text of all pages.  Pages with no extractable
        text (scanned images, blank pages) contribute an empty string.
    """
    reader = PyPDF2.PdfReader(uploaded_file)
    # BUG FIX: extract_text() returns None for image-only pages; the
    # original `pdf_text += page.extract_text()` raised TypeError there.
    return "".join(page.extract_text() or "" for page in reader.pages)

# Load the tokenizer and model once and cache them across Streamlit reruns.
# Streamlit re-executes this entire script on every widget interaction, so
# without caching the multi-gigabyte checkpoint would be re-loaded on every
# button click or text change.
@st.cache_resource
def _load_tokenizer_and_model():
    """Return ``(tokenizer, model)`` for the vi-gemma-2b-RAG checkpoint.

    The model is loaded in bfloat16 to halve memory use and is moved to
    the GPU when one is available.
    """
    tok = AutoTokenizer.from_pretrained("ricepaper/vi-gemma-2b-RAG")
    mdl = AutoModelForCausalLM.from_pretrained(
        "ricepaper/vi-gemma-2b-RAG",
        torch_dtype=torch.bfloat16,
    )
    # Move model to GPU if available
    if torch.cuda.is_available():
        mdl.to("cuda")
    return tok, mdl

# Module-level names kept for the rest of the script (generate_answer).
tokenizer, model = _load_tokenizer_and_model()

# Prompt template fed to the model.  The three "{}" placeholders are filled
# positionally by str.format in generate_answer:
#   1. the retrieved context/document text
#   2. the user's question
#   3. the response — passed as "" so the model completes it
prompt = """
### Instruction and Input:
Based on the following context/document:
{}
Please answer the question: {}
### Response:
{}
"""

# Function to generate answer based on query and context
def generate_answer(context, query):
    """Generate an answer to *query* grounded in *context*.

    Parameters
    ----------
    context : str
        The document text (extracted PDF content) to ground the answer in.
    query : str
        The user's question.

    Returns
    -------
    str
        Only the newly generated response text, with special tokens
        removed.  (The echoed prompt is stripped — see BUG FIX below.)
    """
    input_text = prompt.format(context, query, "")
    # Truncate long documents so the prompt fits the model's context window.
    inputs = tokenizer(input_text, return_tensors="pt", truncation=True, max_length=1024)

    # Keep the inputs on the same device as the model (CPU or GPU).
    inputs = inputs.to(model.device)

    # Inference only: disable autograd to save memory and time.
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=500,
            no_repeat_ngram_size=5,
        )

    # BUG FIX: the original decoded the whole sequence, so the returned
    # "answer" included the entire echoed prompt (context + question).
    # Decode only the tokens generated after the prompt.
    prompt_len = inputs["input_ids"].shape[-1]
    answer = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)
    return answer

# --- Streamlit UI ---
st.title("RAG-Based PDF Question Answering Application")

# Let the user provide the source document.
uploaded_file = st.file_uploader("Upload a PDF file", type="pdf")

if uploaded_file is None:
    st.info("Please upload a PDF file to get started.")
else:
    # Show the extracted text so the user can sanity-check the input.
    pdf_text = extract_text_from_pdf(uploaded_file)
    st.write("Extracted text from PDF:")
    st.text_area("PDF Content", pdf_text, height=200)

    # Collect the question and answer it on demand.
    query = st.text_input("Enter your question about the PDF content:")

    if st.button("Get Answer"):
        if not query.strip():
            st.warning("Please enter a question.")
        else:
            # Run the model over the extracted text and the question.
            answer = generate_answer(pdf_text, query)
            st.write("Answer:", answer)