# NOTE: page-status text captured together with the source, not part of the
# program: "Spaces: Runtime error / Runtime error".
import gradio as gr
from gradio_pdf import PDF
from qdrant_client import models, QdrantClient
from sentence_transformers import SentenceTransformer
from PyPDF2 import PdfReader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.callbacks.manager import CallbackManager
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
# from langchain.llms import LlamaCpp
from langchain.vectorstores import Qdrant
from qdrant_client.http import models
# from langchain.llms import CTransformers
from ctransformers import AutoModelForCausalLM
# --- Embedding model -------------------------------------------------------
# One sentence-embedding model is shared by the whole module; its output
# dimension also sizes the vector collection.
encoder = SentenceTransformer("jinaai/jina-embedding-b-en-v1")
print("embedding model loaded.............................")
print("####################################################")

# --- Language model --------------------------------------------------------
# Streams generated tokens to stdout. Only the LlamaCpp alternative sketched
# below takes it as an argument; the ctransformers loader does not.
callback_manager = CallbackManager([StreamingStdOutCallbackHandler()])
print("loading the LLM......................................")

# Alternative loader kept for reference (needs `from langchain.llms import LlamaCpp`):
#   llm = LlamaCpp(
#       model_path="TheBloke/Llama-2-7B-Chat-GGUF/llama-2-7b-chat.Q8_0.gguf",
#       n_ctx=2048,
#       f16_kv=True,  # MUST set to True, otherwise you will run into problem after a couple of calls
#       callback_manager=callback_manager,
#       verbose=True,
#   )

# Generation settings passed through to the ctransformers model.
llm_options = {
    "temperature": 0.2,
    "repetition_penalty": 1.5,
    "max_new_tokens": 300,
}
llm = AutoModelForCausalLM.from_pretrained(
    "TheBloke/Llama-2-7B-Chat-GGUF",
    model_file="llama-2-7b-chat.Q3_K_S.gguf",
    model_type="llama",
    **llm_options,
)
print("LLM loaded........................................")
print("################################################################")
# Earlier module-level version that indexed one fixed PDF at start-up.
# Kept for reference; chat() performs the same steps per request.
#
# def get_chunks(text):
#     text_splitter = RecursiveCharacterTextSplitter(
#         # separator = "\n",
#         chunk_size = 250,
#         chunk_overlap = 50,
#         length_function = len,
#     )
#     chunks = text_splitter.split_text(text)
#     return chunks
#
# pdf_path = './100 Weird Facts About the Human Body.pdf'
# reader = PdfReader(pdf_path)
# text = ""
# num_of_pages = len(reader.pages)
# for page in range(num_of_pages):
#     current_page = reader.pages[page]
#     text += current_page.extract_text()
# chunks = get_chunks(text)
# print(chunks)
# print("Chunks are ready.....................................")
# print("######################################################")
#
# client = QdrantClient(path = "./db")
# print("db created................................................")
# print("#####################################################################")
#
# client.recreate_collection(
#     collection_name="my_facts",
#     vectors_config=models.VectorParams(
#         size=encoder.get_sentence_embedding_dimension(),  # Vector size is defined by used model
#         distance=models.Distance.COSINE,
#     ),
# )
# print("Collection created........................................")
# print("#########################################################")
#
# li = []
# for i in range(len(chunks)):
#     li.append(i)
# dic = zip(li, chunks)
# dic = dict(dic)
# client.upload_records(
#     collection_name="my_facts",
#     records=[
#         models.Record(
#             id=idx,
#             vector=encoder.encode(dic[idx]).tolist(),
#             payload= {dic[idx][:5] : dic[idx]}
#         ) for idx in dic.keys()
#     ],
# )
# print("Records uploaded........................................")
# print("###########################################################")
# Name of the Qdrant collection that is rebuilt for every uploaded PDF.
_COLLECTION_NAME = "my_facts"
# Maximum number of chunks retrieved as context for one question.
_TOP_K = 3

# System instructions placed inside the Llama-2 <<SYS>> section of the prompt.
_SYSTEM_PROMPT = (
    "You are a helpful assistant, you will use the provided context to answer user questions.\n"
    "Read the given context before answering questions and think step by step. "
    "If you can not answer a user question based on\n"
    "the provided context, inform the user. Do not use any other information for answering user. "
    "Provide a detailed answer to the question."
)


def _read_pdf_text(pdf_path):
    """Return the concatenated text of every page of the PDF at *pdf_path*."""
    reader = PdfReader(pdf_path)
    # A page without extractable text contributes an empty string instead of
    # breaking the concatenation.
    return "".join(page.extract_text() or "" for page in reader.pages)


def _split_into_chunks(text):
    """Split *text* into 250-character chunks that overlap by 50 characters."""
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=250,
        chunk_overlap=50,
        length_function=len,
    )
    return splitter.split_text(text)


def _retrieve_context(chunks, question, limit=_TOP_K):
    """Index *chunks* in the local Qdrant store and return the best matches.

    The collection is recreated on every call, so it only ever holds the
    chunks of the current document. The texts of up to *limit* hits are
    concatenated in the order the search returns them.
    """
    client = QdrantClient(path="./db")
    try:
        client.recreate_collection(
            collection_name=_COLLECTION_NAME,
            vectors_config=models.VectorParams(
                # Vector size is defined by the embedding model in use.
                size=encoder.get_sentence_embedding_dimension(),
                distance=models.Distance.COSINE,
            ),
        )
        # Encode all chunks in one batch instead of one model call per chunk.
        vectors = encoder.encode(chunks)
        client.upload_records(
            collection_name=_COLLECTION_NAME,
            records=[
                models.Record(
                    id=idx,
                    vector=vector.tolist(),
                    payload={"text": chunk},
                )
                for idx, (chunk, vector) in enumerate(zip(chunks, vectors))
            ],
        )
        hits = client.search(
            collection_name=_COLLECTION_NAME,
            query_vector=encoder.encode(question).tolist(),
            limit=limit,
        )
    finally:
        # Release the on-disk store deterministically so the next request
        # can open it again.
        client.close()
    # Joining tolerates fewer than `limit` hits (short documents).
    return "".join(hit.payload["text"] for hit in hits)


def _build_prompt(context, question):
    """Wrap the system prompt, context and question in Llama-2 chat markup."""
    b_inst, e_inst = "[INST]", "[/INST]"
    b_sys, e_sys = "<<SYS>>\n", "\n<</SYS>>\n\n"
    instruction = f"\nContext: {context}\nUser: {question}"
    return b_inst + b_sys + _SYSTEM_PROMPT + e_sys + instruction + e_inst


def chat(file, question):
    """Answer *question* using the content of the uploaded PDF.

    Args:
        file: Path of the uploaded PDF as passed in by the UI; falsy when
            nothing was uploaded.
        question: The user's question.

    Returns:
        The text generated by the language model, or a short explanatory
        message when the PDF or the question is missing or the PDF has no
        extractable text.
    """
    if not file:
        return "Please upload a PDF before asking a question."
    if not question or not question.strip():
        return "Please enter a question."

    chunks = _split_into_chunks(_read_pdf_text(file))
    if not chunks:
        # Nothing to index or retrieve, so there is no context to answer from.
        return "No extractable text was found in the uploaded PDF."

    context = _retrieve_context(chunks, question)
    return llm(_build_prompt(context, question))
# NOTE(review): the "π…" sequences in the UI strings below look like emoji
# whose bytes were decoded with the wrong code page. The original glyphs
# cannot be recovered with certainty from this text, so the strings are kept
# exactly as found — restore them from the upstream source.
pdf_input = PDF(label="Upload a PDF", interactive=True)
question_box = gr.Textbox(lines=10, placeholder="Enter your question here π")
answer_box = gr.Textbox(lines=10, placeholder="Your answer will be here soon π")

# Two inputs (PDF, question) feed chat(); its return value fills the answer box.
screen = gr.Interface(
    fn=chat,
    inputs=[pdf_input, question_box],
    outputs=answer_box,
    title="Q&A with PDF π©π»βπ»πβπ»π‘",
    description="This app facilitates a conversation with PDFs available on https://www.delo.si/assets/media/other/20110728/100%20Weird%20Facts%20About%20the%20Human%20Body.pdfπ‘",
    theme="soft",
    # examples=["Hello", "what is the speed of human nerve impulses?"],
)

# Start the web UI as soon as the module is executed.
screen.launch()