import streamlit as st |
import torch |
import bitsandbytes |
import accelerate |
import scipy |
from PIL import Image |
import torch.nn as nn |
from transformers import Blip2Processor, Blip2ForConditionalGeneration, InstructBlipProcessor, InstructBlipForConditionalGeneration |
from my_model.object_detection import detect_and_draw_objects |
from my_model.captioner.image_captioning import get_caption |
from my_model.utilities import free_gpu_resources |
def answer_question(image, question, model, processor, *, device="cuda",
                    dtype=torch.float16, max_length=100, min_length=20):
    """Answer a free-form question about an image with a generative VQA model.

    Args:
        image: Anything ``PIL.Image.open`` accepts (e.g. a Streamlit
            ``UploadedFile``, a path, or a binary stream), or an
            already-opened ``PIL.Image.Image``.
        question (str): The question to ask about the image.
        model: A Hugging Face model exposing ``generate``. It may be wrapped
            in ``torch.nn.DataParallel``.
        processor: The processor paired with ``model``. It is called as
            ``processor(image, question, return_tensors="pt")`` and is also
            used to decode the generated ids.
        device: Device the processed inputs are moved to. Defaults to
            ``"cuda"``, as before.
        dtype: Floating-point dtype the processed inputs are cast to.
            Defaults to ``torch.float16``, as before.
        max_length (int): Forwarded to ``generate`` (default 100).
        min_length (int): Forwarded to ``generate`` (default 20).

    Returns:
        str: The decoded answer, with special tokens removed and surrounding
        whitespace stripped.
    """
    # Only open the input when it is not already a decoded PIL image.
    if not isinstance(image, Image.Image):
        image = Image.open(image)
    inputs = processor(image, question, return_tensors="pt").to(device, dtype)
    # nn.DataParallel does not expose the wrapped model's `generate`, so call
    # it on the underlying module in that case.
    generator = model.module if isinstance(model, torch.nn.DataParallel) else model
    out = generator.generate(**inputs, max_length=max_length, min_length=min_length)
    return processor.decode(out[0], skip_special_tokens=True).strip()
# NOTE(review): `load_caption_model` is called by the "Get Answer" button
# below, but it was never defined or imported in this file, so that button
# always raised NameError. This loader is a reconstruction. The checkpoint
# name is an ASSUMPTION, chosen to match the Blip2 classes imported at the top
# of the file — confirm it against the model the project actually uses.
VQA_MODEL_NAME = "Salesforce/blip2-opt-2.7b"


def load_caption_model():
    """Load the visual-question-answering model and its processor.

    The model is loaded in float16 to match the dtype ``answer_question``
    casts its inputs to. ``device_map="auto"`` (provided by ``accelerate``)
    handles device placement.

    Returns:
        tuple: ``(model, processor)``, in the order ``answer_question`` expects.
    """
    processor = Blip2Processor.from_pretrained(VQA_MODEL_NAME)
    model = Blip2ForConditionalGeneration.from_pretrained(
        VQA_MODEL_NAME, torch_dtype=torch.float16, device_map="auto"
    )
    return model, processor


# Sidebar page navigation; the chosen page decides which section renders.
st.sidebar.title("Navigation")
selection = st.sidebar.radio(
    "Go to",
    ["Home", "Dataset Analysis", "Evaluation Results", "Run Inference", "Dissertation Report"],
)

if selection == "Home":
    st.title("MultiModal Learning for Knowledge-Based Visual Question Answering")
    st.write("Home page content goes here...")
elif selection == "Dissertation Report":
    st.title("Dissertation Report")
    st.write("Click the button below to download the PDF.")
    # Read the bytes inside a context manager so the file handle is closed.
    # If the file is missing, show an error instead of crashing the page.
    try:
        with open("Files/Dissertation Report.pdf", "rb") as pdf_file:
            pdf_bytes = pdf_file.read()
    except OSError as err:
        st.error(f"Could not load the dissertation report: {err}")
    else:
        st.download_button(
            label="Download PDF",
            data=pdf_bytes,
            file_name="Dissertation Report.pdf",
            mime="application/pdf",
        )
elif selection == "Evaluation Results":
    st.title("Evaluation Results")
    st.write("This is a Place Holder until the contents are uploaded.")
elif selection == "Dataset Analysis":
    st.title("OK-VQA Dataset Analysis")
    st.write("This is a Place Holder until the contents are uploaded.")
elif selection == "Run Inference":
    st.title("Run Inference")
    st.write("This page allows you to run the space for inference.")
    user_input = st.text_input("Enter your text here...")
    if st.button("Run"):
        pass  # TODO: hook up the inference pipeline.

# Image question answering / captioning section. `image` is also read by the
# object-detection handler further down the script.
st.title("Image Question Answering")
image = st.file_uploader("Upload an image", type=["png", "jpg", "jpeg"])
question = st.text_input("Enter your question about the image:")

if st.button('Generate Caption'):
    free_gpu_resources()
    if image is not None:
        st.image(image, use_column_width=True)
        caption = get_caption(image)
        st.write(caption)
        free_gpu_resources()
    else:
        # Captioning needs only an image, so don't ask for a question here.
        st.write("Please upload an image.")

if st.button("Get Answer"):
    if image is not None and question:
        st.image(image, use_column_width=True)
        # Mirror the caption branch: release GPU memory around model use.
        free_gpu_resources()
        model, processor = load_caption_model()
        answer = answer_question(image, question, model, processor)
        st.write(answer)
        # Drop the references so the cleanup below can actually reclaim them.
        del model, processor
        free_gpu_resources()
    else:
        st.write("Please upload an image and enter a question.")
# Sidebar controls for object detection; the "Detect Objects" handler later in
# the script reads `detect_model`, `threshold` and `detect_button`.
st.sidebar.title("Object Detection")
detect_model = st.sidebar.selectbox(
    "Choose a model for object detection:",
    ["detic", "yolov5"],
)
# The slider's starting value depends on the chosen detector:
# 0.2 for yolov5, 0.4 for anything else.
_default_threshold = 0.4
if detect_model == "yolov5":
    _default_threshold = 0.2
threshold = st.sidebar.slider("Select Detection Threshold", 0.1, 0.9, _default_threshold)
detect_button = st.sidebar.button("Detect Objects")
def perform_object_detection(image, model_name, threshold):
    """Run ``detect_and_draw_objects`` and return its two results.

    Parameters
    ----------
    image : PIL.Image.Image
        Picture to run detection on.
    model_name : str
        Detector identifier, forwarded unchanged (the sidebar offers
        ``"detic"`` and ``"yolov5"``).
    threshold : float
        Detection threshold, forwarded unchanged.

    Returns
    -------
    tuple
        ``(annotated_image, objects)``: the image with bounding boxes drawn,
        and the detected objects, exactly as ``detect_and_draw_objects``
        produced them.
    """
    annotated_image, objects = detect_and_draw_objects(image, model_name, threshold)
    return (annotated_image, objects)
# Handle the sidebar "Detect Objects" button. `image` is the uploaded file
# from the question-answering section; `detect_model` and `threshold` come
# from the sidebar controls.
if detect_button:
    if image is None:
        st.write("Please upload an image for object detection.")
    else:
        # Report loading failures and detection failures separately. The
        # previous single try labelled every error "Error loading image".
        try:
            pil_image = Image.open(image)
            st.image(pil_image, use_column_width=True, caption="Original Image")
        except Exception as err:  # UI boundary: show the error, don't crash
            st.error(f"Error loading image: {err}")
        else:
            try:
                processed_image, detected_objects = perform_object_detection(
                    pil_image, detect_model, threshold
                )
                st.image(processed_image, use_column_width=True,
                         caption="Image with Detected Objects")
                st.write(detected_objects)
            except Exception as err:  # UI boundary: show the error, don't crash
                st.error(f"Error during object detection: {err}")