ktllc committed

Commit 72865cb
1 Parent(s): 32d5e37

Update app.py

Files changed (1):
  1. app.py +99 -2
app.py CHANGED
@@ -1,4 +1,101 @@
  import gradio as gr
 
- iface=gr.Interface.load("models/sentence-transformers/sentence-t5-base")
- gr.launch(iface)
+ from langchain.document_loaders import PDFMinerLoader
+ from langchain.text_splitter import CharacterTextSplitter
+ from llama_index import LlamaIndex
+ from transformers import T5ForConditionalGeneration, AutoTokenizer
+ from sentence_transformers import SentenceTransformer
+ import uuid
+ import os
+
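+ # NOTE: PDFMinerLoader and CharacterTextSplitter are LangChain classes; the
+ # original llama_index.* import paths do not exist, so a LangChain install is
+ # assumed here. LlamaIndex is taken to be a local vector-index wrapper whose
+ # add() and query() methods are used below.
+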
+ model_name = 'google/flan-t5-base'
+ model = T5ForConditionalGeneration.from_pretrained(model_name, device_map='auto', offload_folder="offload")
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
+ print('flan read')
+
+ ST_name = 'sentence-transformers/sentence-t5-base'
+ st_model = SentenceTransformer(ST_name)
+ print('sentence read')
+
+ index_path = "llama_index_test"
+ llama_index = LlamaIndex(index_path)
+
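+ # Embed the query and fetch the top-k nearest chunks from the index; the text
+ # of the best hit becomes the context passed to the generator.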
+ def get_context(query_text):
+     query_emb = st_model.encode(query_text)
+     query_response = llama_index.query(query_emb.tolist(), k=4)
+     context = query_response[0][0]
+     context = context.replace('\n', ' ').replace('  ', ' ')
+     return context
+
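+ # Wrap the context and question in a QA prompt and decode a short answer with
+ # Flan-T5 (max_new_tokens=20 keeps replies terse).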
+ def local_query(query, context):
+     t5query = """Using the available context, please answer the question.
+     If you aren't sure, please say I don't know.
+     Context: {}
+     Question: {}
+     """.format(context, query)
+
+     inputs = tokenizer(t5query, return_tensors="pt")
+     outputs = model.generate(**inputs, max_new_tokens=20)
+
+     return tokenizer.batch_decode(outputs, skip_special_tokens=True)
+
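+ # Chat callback: retrieve context, generate an answer, and append the exchange
+ # to the chat history. The btn argument receives the uploaded file from the
+ # inputs list wired up below and is unused here.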
+ def run_query(btn, history, query):
+     context = get_context(query)
+
+     print('calling local query')
+     result = local_query(query, context)
+
+     print('printing result after call back')
+     print(result)
+
+     history.append((query, str(result[0])))
+
+     print('printing history')
+     print(history)
+     return history, ""
+
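+ # Extract text from an uploaded PDF, split it into 1000-character chunks, embed
+ # each chunk, and store the embeddings with their text and fresh UUIDs.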
+ def upload_pdf(file):
+     try:
+         if file is not None:
+             global llama_index
+             file_name = file.name
+
+             loader = PDFMinerLoader(file_name)
+             doc = loader.load()
+
+             text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
+             texts = text_splitter.split_documents(doc)
+             texts = [i.page_content for i in texts]
+
+             doc_emb = st_model.encode(texts)
+             ids = [str(uuid.uuid1()) for _ in doc_emb]
+             llama_index.add(doc_emb.tolist(), texts, ids)
+
+             return 'Successfully uploaded!'
+         else:
+             return "No file uploaded."
+
+     except Exception as e:
+         return f"An error occurred: {e}"
+
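+ # UI: an upload button, a status box, a chat window, and a question field.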
+ with gr.Blocks() as demo:
+     btn = gr.UploadButton("Upload a PDF", file_types=[".pdf"])
+     output = gr.Textbox(label="Output Box")
+     chatbot = gr.Chatbot(height=240)
+
+     with gr.Row():
+         with gr.Column(scale=0.70):
+             txt = gr.Textbox(
+                 show_label=False,
+                 placeholder="Enter a question",
+             )
+
+     # Event handler for uploading a PDF
+     btn.upload(fn=upload_pdf, inputs=[btn], outputs=[output])
+     txt.submit(run_query, [btn, chatbot, txt], [chatbot, txt])
+
+ gr.close_all()
+ demo.queue().launch()
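
The LlamaIndex class imported above is not part of the public llama_index package, and its add()/query() calls do not match any LangChain API either; it is most likely a small local vector-store wrapper shipped alongside app.py. Below is a minimal sketch of such a wrapper, assuming brute-force cosine similarity and pickle persistence. Only the class name, constructor argument, and the add()/query() call shapes come from the commit; the storage format, scoring, and return layout ((text, id) pairs, so that query(...)[0][0] is the best chunk's text) are illustrative assumptions.

# llama_index.py -- hypothetical stand-in for the local wrapper app.py imports.
import os
import pickle

import numpy as np


class LlamaIndex:
    def __init__(self, index_path):
        # Reload a previously pickled index from disk if one exists.
        self.index_path = index_path
        self.embeddings, self.texts, self.ids = [], [], []
        if os.path.exists(index_path):
            with open(index_path, "rb") as f:
                self.embeddings, self.texts, self.ids = pickle.load(f)

    def add(self, embeddings, texts, ids):
        # Append chunk embeddings with their source text and ids, then persist.
        self.embeddings.extend(embeddings)
        self.texts.extend(texts)
        self.ids.extend(ids)
        with open(self.index_path, "wb") as f:
            pickle.dump((self.embeddings, self.texts, self.ids), f)

    def query(self, embedding, k=4):
        # Rank stored chunks by cosine similarity and return the top-k
        # (text, id) pairs; query(...)[0][0] is then the best chunk's text.
        if not self.embeddings:
            return [("", None)]
        emb = np.asarray(self.embeddings, dtype=np.float32)
        q = np.asarray(embedding, dtype=np.float32)
        sims = emb @ q / (np.linalg.norm(emb, axis=1) * np.linalg.norm(q) + 1e-10)
        top = np.argsort(-sims)[:k]
        return [(self.texts[i], self.ids[i]) for i in top]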