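"""Gradio demo: question answering over an uploaded PDF.

The app extracts text from a PDF, splits it into chunks, embeds the chunks
with Sentence-T5, stores them in a vector index, and answers questions by
retrieving the most similar chunk and prompting Flan-T5 with it.
"""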
import os
import uuid

import gradio as gr
from llama_index.document_loaders import PDFMinerLoader
from llama_index.text_splitter import CharacterTextSplitter
from llama_index import LlamaIndex
from sentence_transformers import SentenceTransformer
from transformers import T5ForConditionalGeneration, AutoTokenizer
# Generator model: Flan-T5 writes the final answer from the retrieved context.
model_name = 'google/flan-t5-base'
model = T5ForConditionalGeneration.from_pretrained(model_name, device_map='auto', offload_folder="offload")
tokenizer = AutoTokenizer.from_pretrained(model_name)
print('Loaded Flan-T5')

# Embedding model: Sentence-T5 encodes both document chunks and queries.
ST_name = 'sentence-transformers/sentence-t5-base'
st_model = SentenceTransformer(ST_name)
print('Loaded sentence-transformers model')

# Vector index that stores chunk embeddings and serves similarity queries.
index_path = "llama_index_test"
llama_index = LlamaIndex(index_path)
def get_context(query_text):
    # Embed the query, retrieve the 4 most similar chunks, and use the best match as context.
    query_emb = st_model.encode(query_text)
    query_response = llama_index.query(query_emb.tolist(), k=4)
    context = query_response[0][0]
    context = context.replace('\n', ' ').replace('  ', ' ')
    return context
def local_query(query, context):
    # Prompt Flan-T5 to answer the question using only the retrieved context.
    t5query = """Using the available context, please answer the question.
    If you aren't sure, please say "I don't know".
    Context: {}
    Question: {}
    """.format(context, query)
    inputs = tokenizer(t5query, return_tensors="pt")
    outputs = model.generate(**inputs, max_new_tokens=20)
    return tokenizer.batch_decode(outputs, skip_special_tokens=True)
def run_query(btn, history, query):
    # 'btn' is passed through by the Gradio event wiring but is not used here.
    context = get_context(query)
    result = local_query(query, context)
    # Append the (question, answer) pair to the chat history and clear the input box.
    history.append((query, str(result[0])))
    return history, ""
def upload_pdf(file):
    try:
        if file is not None:
            global llama_index
            file_name = file.name
            # Extract the PDF text and split it into ~1000-character chunks.
            loader = PDFMinerLoader(file_name)
            doc = loader.load()
            text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
            texts = text_splitter.split_documents(doc)
            texts = [i.page_content for i in texts]
            # Embed each chunk and add it to the vector index under a unique id.
            doc_emb = st_model.encode(texts)
            ids = [str(uuid.uuid1()) for _ in doc_emb]
            llama_index.add(doc_emb.tolist(), texts, ids)
            return 'Successfully uploaded!'
        else:
            return "No file uploaded."
    except Exception as e:
        return f"An error occurred: {e}"
# Gradio UI: a PDF upload button, a status box, a chat window, and a question box.
with gr.Blocks() as demo:
    btn = gr.UploadButton("Upload a PDF", file_types=[".pdf"])
    output = gr.Textbox(label="Output Box")
    chatbot = gr.Chatbot(height=240)
    with gr.Row():
        with gr.Column(scale=0.70):
            txt = gr.Textbox(
                show_label=False,
                placeholder="Enter a question",
            )
    # Event handler for uploading a PDF
    btn.upload(fn=upload_pdf, inputs=[btn], outputs=[output])
    # Answer the submitted question, then clear the textbox.
    txt.submit(run_query, [btn, chatbot, txt], [chatbot, txt])

gr.close_all()
demo.queue().launch()