# NOTE: this file was recovered from a Hugging Face Spaces page scrape; the
# "Spaces:" / "Runtime error" status lines from the page header were not code.
import os
import uuid

import gradio as gr
from llama_index import LlamaIndex
from llama_index.document_loaders import PDFMinerLoader
from llama_index.text_splitter import CharacterTextSplitter
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, T5ForConditionalGeneration
# --- One-time setup: generator model, embedding model, vector index ---

# Answer generator: FLAN-T5 base.  device_map='auto' shards the weights
# across whatever devices are available; layers that do not fit are
# offloaded to the "offload" directory on disk.
model_name = 'google/flan-t5-base'
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = T5ForConditionalGeneration.from_pretrained(
    model_name, device_map='auto', offload_folder="offload"
)
print('flan read')

# Sentence embedder used for both document chunks and queries.
ST_name = 'sentence-transformers/sentence-t5-base'
st_model = SentenceTransformer(ST_name)
print('sentence read')

# Vector index shared by upload_pdf (writes) and get_context (reads).
# NOTE(review): LlamaIndex taking a path like this looks like a local/custom
# wrapper rather than the published llama-index API — confirm the dependency.
index_path = "llama_index_test"
llama_index = LlamaIndex(index_path)
def get_context(query_text):
    """Return the text of the stored chunk that best matches *query_text*.

    The query is embedded with the sentence-transformer model, the vector
    index is searched for the 4 nearest chunks, and the top hit's text is
    returned with its whitespace normalized to single spaces.
    """
    query_emb = st_model.encode(query_text)
    query_response = llama_index.query(query_emb.tolist(), k=4)
    context = query_response[0][0]
    # Collapse newlines and runs of spaces into single spaces.  The original
    # code did ``replace(' ', ' ')`` — replacing a space with an identical
    # space, a no-op that was almost certainly meant to collapse doubled
    # spaces left over from the newline substitution.
    context = ' '.join(context.split())
    return context
def local_query(query, context, max_new_tokens=20):
    """Answer *query* with FLAN-T5, grounded on the retrieved *context*.

    Parameters
    ----------
    query : str
        The user's question.
    context : str
        Retrieved document text inserted into the prompt.
    max_new_tokens : int, optional
        Generation budget; defaults to 20 to match the original
        hard-coded value (answers are intentionally terse).

    Returns
    -------
    list[str]
        Decoded generations; callers use element 0.
    """
    # Prompt template preserved verbatim — the model's behavior is tied to
    # the exact wording it was queried with.
    t5query = """Using the available context, please answer the question.
    If you aren't sure please say i don't know.
    Context: {}
    Question: {}
    """.format(context, query)
    inputs = tokenizer(t5query, return_tensors="pt")
    outputs = model.generate(**inputs, max_new_tokens=max_new_tokens)
    return tokenizer.batch_decode(outputs, skip_special_tokens=True)
def run_query(btn, history, query):
    """Chat callback: answer *query* and append the exchange to *history*.

    *btn* is the UploadButton component Gradio passes through from the
    event wiring; it is not used here.  Returns the updated history and
    an empty string, which clears the input textbox.
    """
    context = get_context(query)
    print('calling local query')
    answers = local_query(query, context)
    print('printing result after call back')
    print(answers)
    # Chatbot history is a list of (user_message, bot_message) pairs.
    history.append((query, str(answers[0])))
    print('printing history')
    print(history)
    return history, ""
def upload_pdf(file):
    """Gradio upload handler: extract, chunk, embed and index a PDF.

    Extracts the document text with PDFMiner, splits it into ~1000-character
    chunks, embeds each chunk with the sentence-transformer model, and adds
    (embedding, text, id) triples to the module-level vector index.

    Returns a status string for the output textbox; any failure is reported
    there instead of crashing the app.
    """
    try:
        # Guard clause instead of the original nested if/else.
        if file is None:
            return "No file uploaded."
        loader = PDFMinerLoader(file.name)
        doc = loader.load()
        text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
        chunks = text_splitter.split_documents(doc)
        texts = [chunk.page_content for chunk in chunks]
        doc_emb = st_model.encode(texts)
        # uuid4 gives fully random ids; the original uuid1 embedded the host
        # MAC address and a timestamp, leaking host details for no benefit.
        ids = [str(uuid.uuid4()) for _ in doc_emb]
        # No ``global`` needed: llama_index is only mutated via .add(),
        # never rebound.
        llama_index.add(doc_emb.tolist(), texts, ids)
        return 'Successfully uploaded!'
    except Exception as e:
        # Broad catch is deliberate: surface the error in the UI textbox.
        return f"An error occurred: {e}"
# --- UI layout and event wiring ---
with gr.Blocks() as demo:
    btn = gr.UploadButton("Upload a PDF", file_types=[".pdf"])
    output = gr.Textbox(label="Output Box")
    chatbot = gr.Chatbot(height=240)

    with gr.Row():
        with gr.Column(scale=0.70):
            txt = gr.Textbox(show_label=False, placeholder="Enter a question")

    # Upload -> index the PDF and report status; submit -> answer the
    # question, append it to the chat, and clear the input box.
    btn.upload(fn=upload_pdf, inputs=[btn], outputs=[output])
    txt.submit(run_query, [btn, chatbot, txt], [chatbot, txt])

gr.close_all()
demo.queue().launch()