import gradio as gr
# PDFMinerLoader and CharacterTextSplitter follow langchain's API, which matches
# how they are used below (llama_index does not provide these classes).
from langchain.document_loaders import PDFMinerLoader
from langchain.text_splitter import CharacterTextSplitter
from transformers import T5ForConditionalGeneration, AutoTokenizer
from sentence_transformers import SentenceTransformer
import uuid
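
# No `LlamaIndex` class exists in the llama_index package; the rest of this
# script only needs a vector store exposing `add()` and `query()`. The class
# below is a minimal in-memory stand-in with that interface -- an assumption,
# not a llama_index API.
import numpy as np

class LlamaIndex:
    def __init__(self, name):
        self.name = name          # kept for parity with the original path argument
        self.embeddings = []      # chunk embeddings, one vector per chunk
        self.documents = []       # chunk texts, aligned with self.embeddings
        self.ids = []             # chunk ids, aligned with self.embeddings

    def add(self, embeddings, documents, ids):
        # Append new chunks; no deduplication is attempted.
        self.embeddings.extend(embeddings)
        self.documents.extend(documents)
        self.ids.extend(ids)

    def query(self, embedding, k=4):
        # Return the k most similar chunks as [[text], ...], so that
        # query(...)[0][0] is the text of the best match.
        if not self.embeddings:
            return [[""]]
        mat = np.asarray(self.embeddings, dtype=float)
        q = np.asarray(embedding, dtype=float)
        # Cosine similarity between the query and every stored chunk.
        sims = mat @ q / (np.linalg.norm(mat, axis=1) * np.linalg.norm(q) + 1e-10)
        top = np.argsort(-sims)[:k]
        return [[self.documents[i]] for i in top]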

# Generator: flan-t5-base produces the answers; offload_folder lets accelerate
# spill weights to disk if memory is tight.
model_name = 'google/flan-t5-base'
model = T5ForConditionalGeneration.from_pretrained(model_name, device_map='auto', offload_folder="offload")
tokenizer = AutoTokenizer.from_pretrained(model_name)
print('flan-t5 loaded')

# Embedder: sentence-t5-base encodes both document chunks and user queries.
ST_name = 'sentence-transformers/sentence-t5-base'
st_model = SentenceTransformer(ST_name)
print('sentence-transformer loaded')

index_path = "llama_index_test"
llama_index = LlamaIndex(index_path)

def get_context(query_text):
    # Embed the query and return the text of the best-matching indexed chunk.
    query_emb = st_model.encode(query_text)
    query_response = llama_index.query(query_emb.tolist(), k=4)
    context = query_response[0][0]  # top-ranked chunk
    context = context.replace('\n', ' ').replace('  ', ' ')  # flatten whitespace
    return context

def local_query(query, context):
    # Ask flan-t5 to answer strictly from the retrieved context.
    t5query = """Using the available context, please answer the question.
    If you aren't sure, please say "I don't know".
    Context: {}
    Question: {}
    """.format(context, query)

    inputs = tokenizer(t5query, return_tensors="pt")

    # max_new_tokens=20 caps answers at short phrases; raise it for longer replies.
    outputs = model.generate(**inputs, max_new_tokens=20)

    return tokenizer.batch_decode(outputs, skip_special_tokens=True)
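
# Hypothetical sanity check, runnable once the models above have loaded:
#   print(local_query("What colour is the sky?", "The sky is blue."))
# This should print a one-element list containing a short generated answer.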

def run_query(history, query):
    # Retrieve the most relevant chunk, generate an answer, and append the
    # (question, answer) pair to the chat history.
    context = get_context(query)
    result = local_query(query, context)
    history.append((query, str(result[0])))
    return history, ""  # clear the input box

def upload_pdf(file):
    # Parse the uploaded PDF, split it into ~1000-character chunks, embed each
    # chunk, and add the chunks to the vector index.
    try:
        if file is not None:
            file_name = file.name

            loader = PDFMinerLoader(file_name)
            doc = loader.load()

            text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
            texts = text_splitter.split_documents(doc)
            texts = [t.page_content for t in texts]

            doc_emb = st_model.encode(texts)
            ids = [str(uuid.uuid1()) for _ in doc_emb]  # one unique id per chunk

            llama_index.add(doc_emb.tolist(), texts, ids)
            return 'Successfully uploaded!'
        else:
            return "No file uploaded."

    except Exception as e:
        return f"An error occurred: {e}"

with gr.Blocks() as demo:
    btn = gr.UploadButton("Upload a PDF", file_types=[".pdf"])
    output = gr.Textbox(label="Output Box")
    chatbot = gr.Chatbot(height=240)

    with gr.Row():
        with gr.Column():
            txt = gr.Textbox(
                show_label=False,
                placeholder="Enter a question",
            )

    # Index the uploaded PDF; answer questions against the indexed chunks.
    btn.upload(fn=upload_pdf, inputs=[btn], outputs=[output])
    txt.submit(run_query, [chatbot, txt], [chatbot, txt])

gr.close_all()         # close any stray interfaces from previous runs
demo.queue().launch()  # queue() prevents concurrent generate calls from overlapping
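
# To run (assuming these dependencies are installed):
#   pip install gradio transformers sentence-transformers langchain pdfminer.six numpy
#   python app.py
# Gradio serves the UI on http://localhost:7860 by default.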