|
import base64
import io
import json
import os
import string
import uuid
from typing import Any, Dict, List

import chromadb
import google.generativeai as palm
import pandas as pd
import requests
import streamlit as st
from chromadb.utils.embedding_functions import SentenceTransformerEmbeddingFunction
from langchain.text_splitter import (
    RecursiveCharacterTextSplitter,
    SentenceTransformersTokenTextSplitter,
)
from PIL import Image, ImageDraw, ImageFont
from pypdf import PdfReader
from transformers import pipeline

from utils.cnn_transformer import *
from utils.helpers import *
|
|
|
|
|
# Read the Gemini/PaLM API key from the Streamlit secrets mapping.
# ``.get`` yields None when the key is absent, so main() can show its
# "API Key is not set" message instead of a KeyError being raised here.
api_key = st.secrets.get("PALM_API_KEY")

# Only configure the client when a key is actually available.
if api_key:
    palm.configure(api_key=api_key)
|
|
|
|
|
|
|
def _response_text(response):
    """Return the text of the first candidate in a Gemini API response dict.

    Raises KeyError, IndexError or TypeError when the response does not have
    the ``candidates[0].content.parts[0].text`` structure (including ``None``).
    """
    return response["candidates"][0]["content"]["parts"][0]["text"]


def _render_sidebar_help():
    """Render the static "How to Use the App" guide and the app link in the sidebar."""
    st.sidebar.markdown("\n# 🌟 How to Use the App 🌟\n")

    guide_lines = [
        "",
        "1) **🌈 User Input Magic**:",
        "- 📸 **Camera Snap**: Tap to capture a moment with your device's camera. Say cheese!",
        "- 🖼️ **Image Upload Extravaganza**: Got a cool pic? Upload it from your computer and let the magic begin!",
        "- 📄 **PDF Adventure**: Use gen AI as ctrl+F to search information on any PDF, like opening a treasure chest of information!",
        "- 🧐 **YOLO Algorithm**: Wanna detect the object in the image? Use our object detection algorithm to see if the objects can be detected.",
        "",
        "2) **🤖 AI Interaction Wonderland**:",
        "- 🌟 **Gemini's AI**: Google's Gemini AI is your companion, ready to dive deep into your uploads.",
        "- 🌐 **Chroma Database**: As you upload, we're crafting a colorful Chroma database in our secret lab, making your interaction even more awesome!",
        "",
        "3) **💬 Chit-Chat with AI Post-Upload**:",
        "- 🌍 Once your content is up in the app, ask away! Any question, any time.",
        "- 💡 Light up the conversation with Gemini AI. It is like having a chat with a wise wizard from the digital realm!",
        "",
        "",
    ]
    with st.sidebar:
        with st.expander("Show/Hide"):
            st.markdown("\n".join(guide_lines))

    st.sidebar.markdown(
        "\nEnjoy exploring and have fun! App URL [here](https://huggingface.co/spaces/eagle0504/IDP-Demo)!😄🎉\n"
    )


def _render_ocr(image, image_base64, column):
    """Send the image to the Textract endpoint and show JSON, table and annotation in ``column``.

    ``image`` is the uploaded file object; it is opened under a separate local
    name so the caller's upload handle is never replaced by a PIL image.
    """
    with column:
        st.markdown(
            "# OCR (Optical Character Recognition) - computer vision technology to locate letters/numbers from image."
        )
        st.success("Running textract!")
        url = "https://2tsig211e0.execute-api.us-east-1.amazonaws.com/my_textract"
        payload = {"image": image_base64}
        result_dict = post_request_and_parse_response(url, payload)
        output_data = extract_line_items(result_dict)
        df = pd.DataFrame(output_data)

        with st.expander("Show/Hide Raw Json"):
            st.write(result_dict)

        with st.expander("Show/Hide Table"):
            st.table(df)

        st.success("Bounding boxes drawn!")
        with st.expander("Show/Hide Annotation"):
            # Best effort: the drawing helper is opaque here, so any failure
            # is reported as a warning rather than aborting the page.
            try:
                original_image = Image.open(image)
                image_with_boxes = draw_bounding_boxes_for_textract(
                    original_image.copy(), result_dict
                )
                st.image(
                    image_with_boxes,
                    caption="Annotated Image",
                    use_column_width=True,
                )
            except Exception:
                st.warning("Check textract output!")


def _render_key_value_retrieval(image_base64):
    """Ask Gemini for a fixed set of keys and offer the answers as a CSV download."""
    st.markdown(
        "# Information Retrieval - use Gemini to extract the values for the keys required by the stakeholders"
    )
    with st.spinner("Processing csv to download..."):
        keys = ["First Name", "Last Name", "Policy Number"]
        values = []
        # The API helper may fail in ways not visible from this file, so the
        # whole question loop is guarded and reported as one warning.
        try:
            for k in keys:
                key_response = call_gemini_api(
                    image_base64,
                    api_key,
                    prompt=(
                        "\n"
                        f"What is {k} in this document? Just answer the question directly with a word or two, don't say a complete sentence.\n"
                        "\n"
                        "If there is any special characters, rewrite it w\n"
                    ),
                )
                values.append(_response_text(key_response))
        except Exception:
            st.warning("Please verify document source.")
            return

        sample_payload_output = pd.DataFrame({"Key": keys, "Values": values})

        with st.expander("Inspect table (before download)"):
            st.table(sample_payload_output)

        csv = sample_payload_output.to_csv(index=False).encode("utf-8")
        st.download_button(
            label="Download data as CSV",
            data=csv,
            file_name="data.csv",
            mime="text/csv",
        )


def _render_chat(image_base64):
    """Let the user ask Gemini a free-form question about the image."""
    input_prompt = st.text_input(
        "Type your question here:",
    )
    if not input_prompt:
        return

    try:
        chat_response = call_gemini_api(image_base64, api_key, prompt=input_prompt)
        if chat_response is None:
            st.warning("Check gemini's API.")
            return
        updated_ans = _response_text(chat_response)
    except Exception:
        st.write("No response from API.")
        return

    st.write(f"Gemini: {updated_ans}")


def _render_gemini(image_base64, column):
    """Show Gemini's description of the image plus the optional retrieval and chat tools."""
    with column:
        st.markdown(
            "# Gemini (Generative AI) - read the image content in a general form"
        )
        st.success("Running Gemini!")
        with st.spinner("Wait for it..."):
            response = call_gemini_api(image_base64, api_key)

        with st.expander("Raw output from Gemini"):
            st.write(response)

        try:
            text_from_response = _response_text(response)
        except (KeyError, IndexError, TypeError):
            st.warning("Please check the Gemini API as we do not have response from it.")
        else:
            st.write(text_from_response)

        st.sidebar.success(
            "Check the box if you want to see a sample retrieved (we had a template of keys, in practice this depends on the stakeholders) information to download (only use this if this is a document-based task)! 👇"
        )
        use_retrieval_tech = st.sidebar.checkbox(
            "Retrieve information!",
            value=False,
        )
        if use_retrieval_tech:
            _render_key_value_retrieval(image_base64)

        st.sidebar.success(
            "Check the box if you want to chat with Gemini (do this if you want Gemini to answwer your questions)! 👇"
        )
        use_gemini_to_chat = st.sidebar.checkbox(
            "Chat with Gemini (about the data)!",
            value=False,
        )
        if use_gemini_to_chat:
            _render_chat(image_base64)


def _run_image_analysis(image, input_method):
    """Preview the image, then run OCR (uploads only) and Gemini side by side."""
    with st.expander("Expand/collapse the uploaded image:"):
        st.image(image, caption="Captured Image", use_column_width=True)

    resized_image = resize_image(Image.open(image))
    image_base64 = convert_image_to_base64(resized_image)

    col1, col2 = st.columns(2)

    if input_method == "Upload Image":
        _render_ocr(image, image_base64, col1)

    if api_key:
        _render_gemini(image_base64, col2)
    else:
        st.write("API Key is not set. Please set the API Key.")


def _run_object_detection(image):
    """Optionally run a Hugging Face object-detection pipeline on the image."""
    st.sidebar.success(
        "Check the following box to run YOLO algorithm if desired (only do this if the task at hand is an object detection task)! 👇"
    )
    use_yolo = st.sidebar.checkbox("Use YOLO!", value=False)
    if not use_yolo:
        return

    # Each option is itself the model id handed to the pipeline.
    yolo_option = st.selectbox(
        "Which YOLO algorithm would you like?",
        ("hustvl/yolos-small", "eagle0504/detr-finetuned-balloon-v2"),
    )
    yolo_pipe = pipeline("object-detection", model=yolo_option)

    # Open under a new name: ``image`` stays the uploaded file object.
    source_image = Image.open(image)
    with st.spinner("Wait for it..."):
        st.success("Running YOLO algorithm!")
        predictions = yolo_pipe(source_image)
        st.success("YOLO running successfully.")

    image_with_boxes = draw_boxes(source_image.copy(), predictions)
    st.success("Bounding boxes drawn.")

    st.image(image_with_boxes, caption="Annotated Image", use_column_width=True)


def _run_pdf_qa(uploaded_file):
    """Index the uploaded PDF in an in-memory Chroma collection and answer a query over it."""
    st.sidebar.success("Note: 1 Token ~ 4 Characters.")
    token_size = st.sidebar.slider(
        "Select a token size (when we scrape the document)", 5, 150, 45
    )
    top_n_content = st.sidebar.slider(
        "Select top n content(s) you want to display as reference", 3, 30, 5
    )

    st.success("Your PDF is uploaded successfully.")

    # Parse straight from memory: nothing is written to disk, and the
    # user-supplied file name is never used as a path.
    reader = PdfReader(io.BytesIO(uploaded_file.getvalue()))
    page_texts = ((page.extract_text() or "").strip() for page in reader.pages)
    pdf_texts = [text for text in page_texts if text]
    if not pdf_texts:
        # Nothing to index (e.g. a scanned PDF); adding an empty batch to the
        # collection would fail, so stop here.
        st.warning("No extractable text was found in this PDF.")
        return
    st.success("PDF extracted successfully.")

    character_splitter = RecursiveCharacterTextSplitter(
        separators=["\n\n", "\n", ". ", " ", ""], chunk_size=1000, chunk_overlap=0
    )
    character_split_texts = character_splitter.split_text("\n\n".join(pdf_texts))
    st.success("Texts splitted successfully.")

    st.warning("Start tokenzing ...")
    token_splitter = SentenceTransformersTokenTextSplitter(
        chunk_overlap=5, tokens_per_chunk=token_size
    )
    token_split_texts = []
    for text in character_split_texts:
        token_split_texts.extend(token_splitter.split_text(text))
    if not token_split_texts:
        st.warning("No extractable text was found in this PDF.")
        return
    st.success("Tokenized successfully.")

    # A fresh collection name per run avoids clashing with an existing one.
    collection_name = f"pdf-{uuid.uuid4().hex}"

    embedding_function = SentenceTransformerEmbeddingFunction()
    chroma_client = chromadb.Client()
    chroma_collection = chroma_client.create_collection(
        collection_name, embedding_function=embedding_function
    )
    ids = [str(i) for i in range(len(token_split_texts))]
    chroma_collection.add(ids=ids, documents=token_split_texts)
    st.success("Vector database loaded successfully.")

    query = st.text_input("Ask me anything!", "What is the document about?")
    # Never request more neighbours than there are stored chunks.
    n_results = min(top_n_content, len(token_split_texts))
    results = chroma_collection.query(query_texts=[query], n_results=n_results)
    retrieved_documents = results["documents"][0]
    results_as_table = pd.DataFrame(
        {
            "ids": results["ids"][0],
            "documents": results["documents"][0],
            "distances": results["distances"][0],
        }
    )

    output = rag(query=query, retrieved_documents=retrieved_documents)
    st.write(output)
    st.success(
        "Please see where the chatbot got the information from the document below.👇"
    )
    with st.expander("Raw query outputs:"):
        st.write(results)
    with st.expander("Processed tabular form query outputs:"):
        st.table(results_as_table)


def main():
    """Build the Streamlit page: pick an input source, then run the matching flow.

    Camera and uploaded images go through the image analysis (OCR for uploads,
    Gemini when the module-level ``api_key`` is set) and optional object
    detection; uploaded PDFs go through the retrieval question-answering flow.

    NOTE(review): the retrieval and chat widgets are rendered inside the Gemini
    column; confirm this matches the intended layout.
    """
    st.set_page_config(layout="wide")
    st.title("Generative AI Demo on Camera Input/Image/PDF 💻")

    input_method = st.sidebar.selectbox(
        "Choose input method:", ["Camera", "Upload Image", "Upload PDF"]
    )

    image, uploaded_file = None, None
    if input_method == "Camera":
        image = st.sidebar.camera_input("Take a picture 📸")
    elif input_method == "Upload Image":
        image = st.sidebar.file_uploader("Upload a JPG image", type=["jpg"])
    elif input_method == "Upload PDF":
        uploaded_file = st.sidebar.file_uploader("Choose a PDF file", type="pdf")

    _render_sidebar_help()

    if image is not None:
        _run_image_analysis(image, input_method)
        _run_object_detection(image)

    if uploaded_file is not None:
        _run_pdf_qa(uploaded_file)
|
|
|
|
|
# Build the page only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
|
|