# 1.5_Pint_RAG / app.py
import gradio as gr
import PyPDF2
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

# Initialize the model and tokenizer
device = "cuda" if torch.cuda.is_available() else "cpu"
model = AutoModelForCausalLM.from_pretrained(
    "pints-ai/1.5-Pints-2k-v0.1",
    device_map=device,
    # attn_implementation="flash_attention_2"
)
tokenizer = AutoTokenizer.from_pretrained("pints-ai/1.5-Pints-2k-v0.1")
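
# Read the uploaded PDF with PyPDF2 and concatenate the text of every page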
def extract_pdf_content(file):
    pdf_reader = PyPDF2.PdfReader(file)
    content = ""
    for page in pdf_reader.pages:
        # extract_text() can return None (e.g. for image-only pages)
        content += (page.extract_text() or "") + "\n"
    return content
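
# One chat turn: prepend a system prompt containing the PDF text, apply the
# model's chat template, generate up to 512 new tokens, and return the
# updated history in both message-dict and Chatbot tuple formats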
def chat(message, history, pdf_content):
    # Construct the full conversation: a system prompt carrying the extracted
    # PDF text as context, followed by prior turns and the new user message
    full_history = [
        {"role": "system", "content": f"You are an AI assistant that follows instructions extremely well. Help as much as you can. Use the following information from the uploaded PDF as context: {pdf_content}"}
    ] + history + [{"role": "user", "content": message}]
    text = tokenizer.apply_chat_template(
        full_history,
        tokenize=False,
        add_generation_prompt=True
    )
    model_inputs = tokenizer([text], return_tensors="pt").to(device)
    generated_ids = model.generate(
        model_inputs.input_ids,
        max_new_tokens=512
    )
    # Decode only the newly generated tokens, dropping the echoed prompt
    # and any special tokens such as the end-of-sequence marker
    input_length = len(model_inputs.input_ids[0])
    response = tokenizer.decode(generated_ids[0][input_length:], skip_special_tokens=True)
    # Update the history with the new message and response
    history.append({"role": "user", "content": message})
    history.append({"role": "assistant", "content": response})
    # Return the updated state and (user, assistant) pairs for gr.Chatbot
    return history, [(history[i]["content"], history[i + 1]["content"]) for i in range(0, len(history) - 1, 2)]
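
# Handle a PDF upload: report the extracted length and store the text in state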
def process_pdf(file):
    if file is None:
        return "Please upload a PDF file.", ""
    content = extract_pdf_content(file)
    return f"PDF content extracted. Length: {len(content)} characters.", content
with gr.Blocks() as demo:
    pdf_content = gr.State("")
    chat_history = gr.State([])
    with gr.Row():
        pdf_upload = gr.File(label="Upload PDF", file_types=[".pdf"])
        pdf_output = gr.Textbox(label="PDF Processing Output")
    chatbot = gr.Chatbot()
    msg = gr.Textbox()
    pdf_upload.upload(process_pdf, pdf_upload, [pdf_output, pdf_content])
    msg.submit(chat, [msg, chat_history, pdf_content], [chat_history, chatbot])

demo.launch()