"""Multimodal PDF chatbot (Gradio app).

Pipeline: uploaded PDFs are turned into text (docTR OCR, or pypdf's embedded
text), embedded images are optionally captioned with LLaVA-NeXT when a GPU is
available, both are indexed in an in-memory Chroma client, and questions are
answered by a Hugging Face Inference endpoint using the retrieved context.
"""

# NOTE(review): keep this import order; ZeroGPU's `spaces` package is generally
# expected to be imported before anything initialises CUDA (torch / doctr).
import base64
import chromadb
import gc
import gradio as gr
import io
import numpy as np
import os
import pandas as pd
import pymupdf
from pypdf import PdfReader
import spaces
import torch
from PIL import Image
from chromadb.utils import embedding_functions
from chromadb.utils.data_loaders import ImageLoader
from doctr.io import DocumentFile
from doctr.models import ocr_predictor
from gradio.themes.utils import sizes
from langchain import PromptTemplate
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.llms import HuggingFaceEndpoint
from transformers import LlavaNextForConditionalGeneration, LlavaNextProcessor

# Expected to provide clean_text, image_to_bytes, extract_images and
# extract_pdfs, which are used below but not defined in this file.
from utils import *


def result_to_text(result, as_text=False) -> "str | list[str]":
    """Flatten a docTR OCR result into plain text.

    Each page becomes one string: every block starts on a new tab-indented
    line and every word is followed by a single space. Each page string is
    passed through ``clean_text`` and terminated with a blank line.

    Args:
        result: The object returned by ``ocr_model(...)``; only
            ``result.pages[*].blocks[*].lines[*].words[*].value`` is read.
        as_text: When True, join the page strings with newlines.

    Returns:
        One string when ``as_text`` is True, otherwise a list with one string
        per page.
    """
    full_doc = []
    for page in result.pages:
        text = "".join(
            "\n\t"
            + "".join(word.value + " " for line in block.lines for word in line.words)
            for block in page.blocks
        )
        full_doc.append(clean_text(text) + "\n\n")
    return "\n".join(full_doc) if as_text else full_doc


ocr_model = ocr_predictor(
    "db_resnet50",
    "crnn_mobilenet_v3_large",
    pretrained=True,
    assume_straight_pages=True,
)

# The vision model (and therefore get_image_description) only exists when a
# CUDA device is present; callers must tolerate the function being undefined.
if torch.cuda.is_available():
    processor = LlavaNextProcessor.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf")
    vision_model = LlavaNextForConditionalGeneration.from_pretrained(
        "llava-hf/llava-v1.6-mistral-7b-hf",
        torch_dtype=torch.float16,
        low_cpu_mem_usage=True,
        load_in_4bit=True,
    )

    @spaces.GPU()
    def get_image_description(image):
        """Return a one-sentence LLaVA-NeXT caption for ``image``.

        At most 100 new tokens are generated. Only the newly generated tokens
        are decoded, so the instruction prompt is not part of the result.
        """
        torch.cuda.empty_cache()
        gc.collect()
        # LLaVA-NeXT needs the <image> placeholder to bind the image to the prompt.
        prompt = "[INST] <image>\nDescribe the image in a sentence [/INST]"
        inputs = processor(text=prompt, images=image, return_tensors="pt").to("cuda:0")
        output = vision_model.generate(**inputs, max_new_tokens=100)
        # generate() returns prompt ids followed by new ids; keep only the new ones.
        new_tokens = output[0][inputs["input_ids"].shape[-1]:]
        return processor.decode(new_tokens, skip_special_tokens=True).strip()


CSS = """
#table_col {background-color: rgb(33, 41, 54);}
"""

# Status text shown in the "Extract Data" tab until extraction has run.
EXTRACT_PROMPT_HTML = "\n\nClick the 'Extract' button to extract data from PDFs\n\n"

# Shared embedding function for both collections; loaded once at import time
# instead of once per extraction.
sentence_transformer_ef = embedding_functions.SentenceTransformerEmbeddingFunction(
    model_name="multi-qa-mpnet-base-dot-v1"
)


def get_vectordb(text, images, img_doc_files):
    """Index extracted text and images in fresh Chroma collections.

    Args:
        text: All extracted document text; split into ~500-character chunks
            and stored in the ``text_db`` collection.
        images: Extracted images (as returned by ``extract_images``); each is
            captioned and stored in the ``image_db`` collection, with the
            ``image_to_bytes`` encoding kept in the ``"image"`` metadata key.
        img_doc_files: Source file name for each entry of ``images`` (same
            order); prepended to that image's description.

    Returns:
        The Chroma client holding ``text_db`` and ``image_db``.
    """
    client = chromadb.EphemeralClient()
    loader = ImageLoader()

    # NOTE(review): stale collections are dropped because EphemeralClient
    # instances in one process appear to share storage; that would also mean
    # concurrent sessions overwrite each other's data — confirm.
    # Newer chromadb returns names from list_collections(), older returns
    # Collection objects; accept both.
    existing = {getattr(c, "name", c) for c in client.list_collections()}
    for name in ("text_db", "image_db"):
        if name in existing:
            client.delete_collection(name)

    text_collection = client.get_or_create_collection(
        name="text_db",
        embedding_function=sentence_transformer_ef,
        data_loader=loader,
    )
    image_collection = client.get_or_create_collection(
        name="image_db",
        embedding_function=sentence_transformer_ef,
        data_loader=loader,
        metadata={"hnsw:space": "cosine"},
    )

    descs = []
    for i, image in enumerate(images):
        try:
            descs.append(img_doc_files[i] + "\n" + get_image_description(image))
        except Exception as err:
            # Best effort: a failed caption (including NameError when no GPU
            # means get_image_description was never defined) must not abort
            # indexing of the remaining data.
            print(f"Image description failed for image {i}: {err!r}")
            descs.append("Could not generate image description due to some error")
        print(descs[-1])
        print()

    if images:
        image_collection.add(
            ids=[str(i) for i in range(len(images))],
            documents=descs,
            metadatas=[{"image": image_to_bytes(img)} for img in images],
        )

    if not text.strip():
        # Warn rather than raise so image-only PDFs can still be indexed.
        gr.Warning("No text found in documents")
    else:
        splitter = RecursiveCharacterTextSplitter(
            chunk_size=500,
            chunk_overlap=10,
        )
        doc_texts = [d.page_content for d in splitter.create_documents([text])]
        # Chroma rejects an empty id list, so only add when chunks exist.
        if doc_texts:
            text_collection.add(
                ids=[str(i) for i in range(len(doc_texts))],
                documents=doc_texts,
            )
    return client


def extract_only_text(reader):
    """Return the embedded (non-OCR) text of every page of a pypdf reader.

    Page texts are joined with newlines; the result is stripped.
    """
    return "\n".join(page.extract_text() or "" for page in reader.pages).strip()


def extract_data_from_pdfs(
    docs, session, include_images, do_ocr, progress=gr.Progress()
):
    """Extract text (and optionally images) from the PDFs and index them.

    ``progress=gr.Progress()`` is Gradio's injection idiom, not a shared
    mutable default.

    Args:
        docs: The selected documents; each entry is passed to
            ``DocumentFile.from_pdf`` / ``PdfReader`` / ``extract_images`` and
            to ``os.path.basename`` (presumably file paths — verify against
            ``extract_pdfs``).
        session: Per-session state dict; ``"processed"`` is set to True.
        include_images: Radio value; images are extracted only for
            ``"Include Images"``.
        do_ocr: Radio value; OCR is used only for ``"Get Text With OCR"``.
        progress: Gradio progress tracker.

    Returns:
        Tuple for the outputs (vectordb, session, sample_data, ext_text,
        images, prog): the Chroma client, the session dict, a visibility
        update, the first 2000 characters of text, up to two sample images
        and a completion message.

    Raises:
        gr.Error: If no documents are selected.
    """
    if not docs:
        raise gr.Error("No documents to process")
    progress(0, "Extracting Images")
    progress(0.25, "Extracting Text")
    all_text = ""
    images = []
    img_docs = []
    for doc in docs:
        if do_ocr == "Get Text With OCR":
            pdf_doc = DocumentFile.from_pdf(doc)
            result = ocr_model(pdf_doc)
            all_text += result_to_text(result, as_text=True) + "\n\n"
        else:
            reader = PdfReader(doc)
            all_text += extract_only_text(reader) + "\n\n"
        if include_images == "Include Images":
            imgs = extract_images([doc])
            images.extend(imgs)
            img_docs.extend([os.path.basename(doc)] * len(imgs))

    progress(
        0.6, "Generating image descriptions and inserting everything into vectorDB"
    )
    vectordb = get_vectordb(all_text, images, img_docs)
    progress(1, "Completed")
    session["processed"] = True
    return (
        vectordb,
        session,
        # sample_data is a gr.Column, so send a Column update.
        gr.Column(visible=True),
        all_text[:2000] + "...",
        images[:2],
        "\n\nCompleted\n\n",
    )


def conversation(
    vectordb_client,
    msg,
    num_context,
    img_context,
    history,
    temperature,
    max_new_tokens,
    hf_token,
    model_path,
):
    """Answer ``msg`` with an LLM using text and image context from Chroma.

    Args:
        vectordb_client: Client returned by ``get_vectordb`` (None before
            extraction has run).
        msg: User question; also used as the retrieval query.
        num_context: Number of text chunks to retrieve.
        img_context: Number of image descriptions to retrieve.
        history: Chatbot history as a list of (user, bot) pairs.
        temperature: Sampling temperature for the endpoint.
        max_new_tokens: Generation length limit for the endpoint.
        hf_token: Optional user token; used only together with model_path.
        model_path: Optional user repo id; used only together with hf_token.

    Returns:
        (updated history, retrieved text chunks, retrieved PIL images).

    Raises:
        gr.Error: If no data has been extracted yet.
    """
    if vectordb_client is None:
        raise gr.Error("Please extract data first")

    hf_token = (hf_token or "").strip()
    model_path = (model_path or "").strip()
    if hf_token and model_path:
        repo_id, api_token = model_path, hf_token
    else:
        # None (not the string "None") lets the endpoint fall back to its own
        # environment lookup instead of authenticating with a bogus token.
        repo_id, api_token = "meta-llama/Meta-Llama-3-8B-Instruct", os.getenv("P_HF_TOKEN")
    llm = HuggingFaceEndpoint(
        repo_id=repo_id,
        temperature=temperature,
        max_new_tokens=int(max_new_tokens),
        huggingfacehub_api_token=api_token,
    )

    text_collection = vectordb_client.get_collection(
        "text_db", embedding_function=sentence_transformer_ef
    )
    image_collection = vectordb_client.get_collection(
        "image_db", embedding_function=sentence_transformer_ef
    )
    # Slider values may arrive as floats; Chroma requires an int n_results.
    results = text_collection.query(
        query_texts=[msg], include=["documents"], n_results=int(num_context)
    )["documents"][0]
    similar_images = image_collection.query(
        query_texts=[msg],
        include=["metadatas", "documents"],
        n_results=int(img_context),
    )
    # "image" metadata holds image_to_bytes output, decoded here as base64.
    img_links = [meta["image"] for meta in similar_images["metadatas"][0]]
    images_and_locs = [
        Image.open(io.BytesIO(base64.b64decode(link))) for link in img_links
    ]
    if img_links:
        img_desc = "\n".join(similar_images["documents"][0])
    else:
        img_desc = "No Images Are Provided"

    template = """
Context:
{context}

Included Images:
{images}

Question:
{question}

Answer:
"""
    prompt = PromptTemplate(
        template=template, input_variables=["context", "images", "question"]
    )
    context = "\n\n".join(results)
    response = llm.invoke(prompt.format(context=context, question=msg, images=img_desc))
    return (history or []) + [(msg, response)], results, images_and_locs


def check_validity_and_llm(session_states):
    """Switch to the Chat tab, but only after data has been extracted.

    Raises:
        gr.Error: If the session has not been processed yet.
    """
    if session_states.get("processed", False):
        return gr.Tabs(selected=2)
    raise gr.Error("Please extract data first")


with gr.Blocks(css=CSS, theme=gr.themes.Soft(text_size=sizes.text_md)) as demo:
    vectordb = gr.State()
    doc_collection = gr.State(value=[])
    session_states = gr.State(value={})
    references = gr.State(value=[])

    gr.Markdown(
        "\n\nMultimodal PDF Chatbot\n\nInteract With Your PDF Documents\n\n"
    )
    gr.Markdown(
        "\n\nNote: This application leverages advanced Retrieval-Augmented "
        "Generation (RAG) techniques to provide context-aware responses from "
        "your PDF documents\n\n\n"
        "Utilizing multimodal capabilities, this chatbot can interpret and "
        "answer queries based on both textual and visual information within "
        "your PDFs.\n"
    )
    gr.Markdown(
        "\nWarning: Extracting text and images from your document and "
        "generating embeddings may take some time due to the use of OCR and "
        "multimodal LLMs for image description\n"
    )

    with gr.Tabs() as tabs:
        with gr.TabItem("Upload PDFs", id=0) as pdf_tab:
            with gr.Row():
                with gr.Column():
                    documents = gr.File(
                        file_count="multiple",
                        # Gradio expects extensions with a leading dot.
                        file_types=[".pdf"],
                        interactive=True,
                        label="Upload your PDF file/s",
                    )
                    pdf_btn = gr.Button(value="Next", elem_id="button1")

        with gr.TabItem("Extract Data", id=1) as preprocess:
            with gr.Row():
                with gr.Column():
                    back_p1 = gr.Button(value="Back")
                with gr.Column():
                    embed = gr.Button(value="Extract Data")
                with gr.Column():
                    next_p1 = gr.Button(value="Next")
            with gr.Row():
                include_images = gr.Radio(
                    ["Include Images", "Exclude Images"],
                    value="Include Images",
                    label="Include/ Exclude Images",
                    interactive=True,
                )
                do_ocr = gr.Radio(
                    ["Get Text With OCR", "Get Available Text Only"],
                    value="Get Text With OCR",
                    label="OCR/ No OCR",
                    interactive=True,
                )
            with gr.Row(equal_height=True, variant="panel") as row:
                selected = gr.Dataframe(
                    interactive=False,
                    col_count=(1, "fixed"),
                    headers=["Selected Files"],
                )
                prog = gr.HTML(value=EXTRACT_PROMPT_HTML)
            with gr.Accordion("See Parts of Extracted Data", open=False):
                with gr.Column(visible=True) as sample_data:
                    with gr.Row():
                        with gr.Column():
                            ext_text = gr.Textbox(
                                label="Sample Extracted Text", lines=15
                            )
                        with gr.Column():
                            images = gr.Gallery(
                                label="Sample Extracted Images", columns=1, rows=2
                            )

        with gr.TabItem("Chat", id=2) as chat_tab:
            with gr.Accordion("Config (Advanced) (Optional)", open=False):
                with gr.Row(variant="panel", equal_height=True):
                    choice = gr.Radio(
                        ["chromaDB"],
                        value="chromaDB",
                        label="Vector Database",
                        interactive=True,
                    )
                    with gr.Accordion("Use your own model (optional)", open=False):
                        hf_token = gr.Textbox(
                            label="HuggingFace Token", interactive=True
                        )
                        model_path = gr.Textbox(label="Model Path", interactive=True)
                with gr.Row(variant="panel", equal_height=True):
                    num_context = gr.Slider(
                        label="Number of text context elements",
                        minimum=1,
                        maximum=20,
                        step=1,
                        interactive=True,
                        value=3,
                    )
                    img_context = gr.Slider(
                        label="Number of image context elements",
                        minimum=1,
                        maximum=10,
                        step=1,
                        interactive=True,
                        value=2,
                    )
                with gr.Row(variant="panel", equal_height=True):
                    temp = gr.Slider(
                        label="Temperature",
                        minimum=0.1,
                        maximum=1,
                        step=0.1,
                        interactive=True,
                        value=0.4,
                    )
                    max_tokens = gr.Slider(
                        label="Max Tokens",
                        minimum=10,
                        maximum=2000,
                        step=10,
                        interactive=True,
                        value=500,
                    )
            with gr.Row():
                with gr.Column():
                    # The first positional Gallery argument is its value, so
                    # the title must be passed as label=.
                    ret_images = gr.Gallery(label="Similar Images", columns=1, rows=2)
                with gr.Column():
                    chatbot = gr.Chatbot(height=400)
                    with gr.Accordion("Text References", open=False):

                        @gr.render(inputs=references)
                        def gen_refs(refs):
                            """Render one read-only textbox per retrieved chunk."""
                            for i, ref in enumerate(refs, start=1):
                                gr.Textbox(label=f"Reference-{i}", value=ref, lines=3)

            with gr.Row():
                msg = gr.Textbox(
                    placeholder="Type your question here (e.g. 'What is this document about?')",
                    interactive=True,
                    container=True,
                )
            with gr.Row():
                submit_btn = gr.Button("Submit message")
                clear_btn = gr.ClearButton([msg, chatbot], value="Clear conversation")

    pdf_btn.click(
        fn=extract_pdfs,
        inputs=[documents, doc_collection],
        outputs=[doc_collection, tabs, selected],
    )
    embed.click(
        extract_data_from_pdfs,
        inputs=[doc_collection, session_states, include_images, do_ocr],
        outputs=[
            vectordb,
            session_states,
            sample_data,
            ext_text,
            images,
            prog,
        ],
    )

    # Button click and Enter in the textbox run the same chat handler.
    chat_inputs = [
        vectordb,
        msg,
        num_context,
        img_context,
        chatbot,
        temp,
        max_tokens,
        hf_token,
        model_path,
    ]
    chat_outputs = [chatbot, references, ret_images]
    submit_btn.click(conversation, chat_inputs, chat_outputs)
    msg.submit(conversation, chat_inputs, chat_outputs)

    documents.change(lambda: EXTRACT_PROMPT_HTML, None, prog)

    back_p1.click(lambda: gr.Tabs(selected=0), None, tabs)
    next_p1.click(check_validity_and_llm, session_states, tabs)


if __name__ == "__main__":
    demo.launch(server_port=8008)