KB-VQA-E / app.py
m7mdal7aj's picture
Update app.py
aebc520 verified
raw
history blame
7 kB
import streamlit as st
import torch
import bitsandbytes
import accelerate
import scipy
import copy
from PIL import Image
import torch.nn as nn
from my_model.object_detection import detect_and_draw_objects
from my_model.captioner.image_captioning import get_caption
from my_model.gen_utilities import free_gpu_resources
from my_model.KBVQA import KBVQA, prepare_kbvqa_model
def answer_question(caption, detected_objects_str, question, model):
    """Ask ``model`` to answer ``question`` about an analysed image.

    Args:
        caption: Caption previously produced for the image.
        detected_objects_str: Detected-objects description for the image.
        question: The user's question.
        model: Object exposing ``generate_answer(question, caption, objects)``.

    Returns:
        Whatever ``model.generate_answer`` returns.

    Note that ``generate_answer`` takes the question first, so the arguments
    are reordered relative to this function's own signature.
    """
    return model.generate_answer(question, caption, detected_objects_str)
def get_caption(image):
    """Placeholder captioner: ignore ``image`` and return a fixed caption.

    NOTE(review): this module-level definition shadows the ``get_caption``
    imported from ``my_model.captioner.image_captioning``. Nothing in this
    file calls the bare name (``analyze_image`` uses ``model.get_caption``).
    """
    placeholder_caption = "Generated caption for the image"
    return placeholder_caption
def free_gpu_resources():
    """No-op placeholder for GPU cleanup.

    NOTE(review): this definition shadows the ``free_gpu_resources`` imported
    from ``my_model.gen_utilities``, so calls to this name do nothing.
    """
    return None
# Paths of the sample images offered as clickable thumbnails
# (Files/sample1.jpg through Files/sample7.jpg).
sample_images = [f"Files/sample{number}.jpg" for number in range(1, 8)]
def analyze_image(image, model):
    """Caption ``image`` and detect its objects using ``model``.

    The caption is computed first, then object detection.

    Args:
        image: The image to analyse.
        model: Object exposing ``get_caption(image)`` and
            ``detect_objects(image)``; the latter returns a pair whose second
            element is the detected-objects string.

    Returns:
        ``(caption, detected_objects_str)``. The annotated image returned by
        ``detect_objects`` is discarded.
    """
    image_caption = model.get_caption(image)
    _annotated_image, objects_summary = model.detect_objects(image)
    return image_caption, objects_summary
def _reset_image_state(image):
    """Make ``image`` the current image and clear everything derived from the previous one."""
    st.session_state['current_image'] = image
    st.session_state['qa_history'] = []
    st.session_state['analysis_done'] = False
    st.session_state['answer_in_progress'] = False
    # Cached analysis results belong to the previously selected image.
    st.session_state['caption'] = None
    st.session_state['detected_objects_str'] = None


def image_qa_app(kbvqa):
    """Render the image question-answering UI backed by ``kbvqa``.

    The user picks a sample image or uploads one, runs an analysis pass
    (caption + object detection via ``analyze_image``) and can then ask
    questions that are answered via ``answer_question``. All state is kept in
    ``st.session_state`` because Streamlit re-executes this function on every
    widget interaction; plain locals do not survive between clicks.

    Args:
        kbvqa: The loaded model, passed through to ``analyze_image`` and
            ``answer_question``.
    """
    # Initialize session state for the current image, its analysis and Q&A history.
    defaults = {
        'current_image': None,
        'qa_history': [],
        'analysis_done': False,
        'answer_in_progress': False,
        'caption': None,
        'detected_objects_str': None,
        'uploaded_image_signature': None,
    }
    for key, value in defaults.items():
        if key not in st.session_state:
            st.session_state[key] = value

    # Display sample images as clickable thumbnails.
    st.write("Choose from sample images:")
    cols = st.columns(len(sample_images))
    for idx, (col, sample_image_path) in enumerate(zip(cols, sample_images)):
        with col:
            # Image.open is lazy and keeps the file open; copy() loads the
            # pixels so the handle can be closed on every rerun.
            with Image.open(sample_image_path) as img:
                image = img.copy()
            st.image(image, use_column_width=True)
            if st.button(f'Select Sample Image {idx + 1}', key=f'sample_{idx}'):
                _reset_image_state(image)

    # Image uploader.
    uploaded_image = st.file_uploader("Or upload an Image", type=["png", "jpg", "jpeg"])
    if uploaded_image is not None:
        # The uploader returns the same file on every rerun. Only treat it as
        # a new selection when the file actually changes; otherwise every
        # click would wipe the analysis and the Q&A history.
        signature = (uploaded_image.name, uploaded_image.size)
        if signature != st.session_state['uploaded_image_signature']:
            st.session_state['uploaded_image_signature'] = signature
            _reset_image_state(Image.open(uploaded_image))
    else:
        # Upload cleared: a later re-upload of the same file counts as new.
        st.session_state['uploaded_image_signature'] = None

    # Display the current image (unaltered). PIL images are compared to None
    # explicitly rather than relying on object truthiness.
    current_image = st.session_state['current_image']
    if current_image is not None:
        st.image(current_image, caption='Uploaded Image.', use_column_width=True)

    # Analyze Image button.
    if current_image is not None and not st.session_state['analysis_done']:
        if st.button('Analyze Image'):
            caption, detected_objects_str = analyze_image(current_image, kbvqa)
            # Persist the results: the "Get Answer" click happens in a later
            # rerun in which this branch is skipped.
            st.session_state['caption'] = caption
            st.session_state['detected_objects_str'] = detected_objects_str
            st.session_state['analysis_done'] = True
            st.session_state['processed_image'] = copy.deepcopy(current_image)

    # Get Answer button.
    if st.session_state['analysis_done'] and not st.session_state['answer_in_progress']:
        question = st.text_input("Ask a question about this image:")
        if st.button('Get Answer'):
            if not question.strip():
                st.warning("Please enter a question first.")
            else:
                st.session_state['answer_in_progress'] = True
                try:
                    answer = answer_question(
                        st.session_state['caption'],
                        st.session_state['detected_objects_str'],
                        question,
                        model=kbvqa,
                    )
                    st.session_state['qa_history'].append((question, answer))
                finally:
                    # Always clear the flag so a failed generation does not
                    # leave the question input hidden.
                    st.session_state['answer_in_progress'] = False

    # Display all Q&A for the current image.
    for q, a in st.session_state['qa_history']:
        st.text(f"Q: {q}\nA: {a}\n")
def run_inference():
    """Render the "Run Inference" page.

    Lets the user choose an inference method, an object-detection backend and
    a detection confidence level. Loads the KBVQA model on demand via
    ``prepare_kbvqa_model`` and hands the loaded model to ``image_qa_app``.
    The detector the model was loaded with is remembered, so picking a
    different detector and pressing "Load KBVQA Model" reloads the model
    instead of silently keeping the old one.
    """
    st.title("Run Inference")
    method = st.selectbox(
        "Choose a method:",
        ["Fine-Tuned Model", "In-Context Learning (n-shots)"],
        index=0,  # Default to the first option
    )
    detection_model = st.selectbox(
        "Choose a model for object detection:",
        ["yolov5", "detic"],
        index=0,  # Default to the first option
    )
    # Set default confidence based on the selected model.
    default_confidence = 0.2 if detection_model == "yolov5" else 0.4
    # Slider for confidence level.
    # NOTE(review): confidence_level is collected but not yet passed to the
    # model; prepare_kbvqa_model() only receives the detector name here.
    confidence_level = st.slider(
        "Select Detection Confidence Level",
        min_value=0.1,
        max_value=0.9,
        value=default_confidence,
        step=0.1,
    )

    if method != "Fine-Tuned Model":
        # Only the fine-tuned pipeline is wired up; say so instead of
        # rendering an empty page.
        st.write("This method is not available yet.")
        return

    # Initialize session state for the model and the detector it was built with.
    if 'kbvqa' not in st.session_state:
        st.session_state['kbvqa'] = None
    if 'kbvqa_detection_model' not in st.session_state:
        st.session_state['kbvqa_detection_model'] = None

    # Button to load KBVQA models.
    if st.button('Load KBVQA Model'):
        loaded_detector = st.session_state['kbvqa_detection_model']
        if st.session_state['kbvqa'] is not None and loaded_detector == detection_model:
            st.write("Model already loaded.")
        else:
            if st.session_state['kbvqa'] is not None:
                # A different detector was selected: drop the old model
                # before building the new one.
                st.session_state['kbvqa'] = None
                free_gpu_resources()
            # Call the function to load models and show progress.
            st.session_state['kbvqa'] = prepare_kbvqa_model(detection_model)
            st.session_state['kbvqa_detection_model'] = detection_model
            if st.session_state['kbvqa']:
                st.write("Model is ready for inference.")

    kbvqa = st.session_state['kbvqa']
    if kbvqa:
        loaded_detector = st.session_state['kbvqa_detection_model']
        if loaded_detector is not None and loaded_detector != detection_model:
            st.info(
                f"The loaded model uses '{loaded_detector}' for object detection. "
                f"Click 'Load KBVQA Model' to switch to '{detection_model}'."
            )
        image_qa_app(kbvqa)
    else:
        st.write('Model is not ready for inference yet')
def _show_dissertation_report():
    """Render the dissertation report page with a download button for the PDF."""
    st.title("Dissertation Report")
    st.write("Click the button below to download the PDF.")
    report_path = "Files/Dissertation Report.pdf"
    try:
        # Read the bytes inside a context manager so the file handle is
        # closed as soon as the content is loaded.
        with open(report_path, "rb") as report_file:
            report_bytes = report_file.read()
    except OSError:
        st.error(f"Could not read the report at '{report_path}'.")
        return
    st.download_button(
        label="Download PDF",
        data=report_bytes,
        file_name="Dissertation Report.pdf",
        mime="application/pdf",
    )


# Main function
def main():
    """Entry point: render the sidebar navigation and the selected page."""
    st.sidebar.title("Navigation")
    selection = st.sidebar.radio("Go to", ["Home", "Dataset Analysis", "Evaluation Results", "Run Inference", "Dissertation Report", "Object Detection"])
    if selection == "Home":
        st.title("MultiModal Learning for Knowledge-Based Visual Question Answering")
        st.write("Home page content goes here...")
    elif selection == "Dissertation Report":
        _show_dissertation_report()
    elif selection == "Evaluation Results":
        st.title("Evaluation Results")
        st.write("This is a Place Holder until the contents are uploaded.")
    elif selection == "Dataset Analysis":
        st.title("OK-VQA Dataset Analysis")
        st.write("This is a Place Holder until the contents are uploaded.")
    elif selection == "Run Inference":
        run_inference()
    elif selection == "Object Detection":
        # run_object_detection may not be defined in this module; look it up
        # at call time so this page shows a message instead of raising NameError.
        object_detection_page = globals().get("run_object_detection")
        if object_detection_page is None:
            st.title("Object Detection")
            st.write("This page is not available yet.")
        else:
            object_detection_page()


if __name__ == "__main__":
    main()