# Source: KB-VQA-E / my_model/KBVQA.py
# Author: m7mdal7aj — last change: "Update my_model/KBVQA.py" (commit a97003a, verified)
# Main script for KBVQA: Knowledge-Based Visual Question Answering Module
# This module is the central component for implementing the designed model architecture for the Knowledge-Based Visual
# Question Answering (KB-VQA) project. It integrates various sub-modules, including image captioning, object detection,
# and a fine-tuned language model, to provide a comprehensive solution for answering questions based on visual input.
# --- Description ---
# **KBVQA class**:
# The KBVQA class encapsulates the functionality needed to perform visual question answering using a combination of
# multimodal models.
# The class handles the following tasks:
# - Loading and managing a fine-tuned language model (LLaMA-2) for question answering.
# - Integrating an image captioning model to generate descriptive captions for input images.
# - Utilizing an object detection model to identify and describe objects within the images.
# - Formatting and generating prompts for the language model based on the image captions and detected objects.
# - Providing methods to analyze images and generate answers to user-provided questions.
# **prepare_kbvqa_model function**:
# - The prepare_kbvqa_model function orchestrates the loading and initialization of the KBVQA class, ensuring it is
# ready for inference.
# ---Instructions---
# **Model Preparation**:
# Use the prepare_kbvqa_model function to prepare and initialize the KBVQA system, ensuring all required models are
# loaded and ready for use.
# **Image Processing and Question Answering**:
# Use the get_caption method to generate captions for input images.
# Use the detect_objects method to identify and describe objects in the images.
# Use the generate_answer method to answer questions based on the image captions and detected objects.
# This module forms the backbone of the KB-VQA project, integrating advanced models to provide an end-to-end solution
# for visual question answering tasks.
# Ensure all dependencies are installed and the required configuration file is in place before running this script.
# The configurations for the KBVQA class are defined in the 'my_model/config/kbvqa_config.py' file.
# ---------- Import this module (running it directly does nothing) to utilize the full KB-VQA functionality ----------#
# ---------- Please ensure this is run on a GPU ----------#
import streamlit as st
import torch
from PIL import Image
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from typing import Tuple, Optional
from my_model.utilities.gen_utilities import free_gpu_resources
from my_model.captioner.image_captioning import ImageCaptioningModel
from my_model.detector.object_detection import ObjectDetector
import my_model.config.kbvqa_config as config
class KBVQA:
    """
    The KBVQA class encapsulates the functionality for the Knowledge-Based Visual Question Answering (KBVQA) model.
    It integrates various components such as an image captioning model, an object detection model, and a language
    model (LLaMA-2) fine-tuned on the OK-VQA dataset for generating answers to visual questions.

    Attributes:
        kbvqa_model_name (Optional[str]): Name of the fine-tuned language model used for KBVQA, chosen from
            st.session_state["method"]; None when the method is not one of the supported fine-tuned models.
        quantization (str): The quantization setting for the model (e.g., '4bit', '8bit').
        max_context_window (int): The maximum number of tokens allowed in the model's context window.
        add_eos_token (bool): Flag to indicate whether the tokenizer adds an end-of-sequence token.
        trust_remote (bool): Flag to indicate whether to trust remote code when loading the tokenizer.
        use_fast (bool): Flag to indicate whether to use the fast version of the tokenizer.
        low_cpu_mem_usage (bool): Configured low-CPU-memory loading flag. Model loading always uses True because
            device_map="auto" requires it.
        kbvqa_tokenizer (Optional[AutoTokenizer]): The tokenizer for the KBVQA model.
        captioner (Optional[ImageCaptioningModel]): The model used for generating image captions.
        detector (Optional[ObjectDetector]): The object detection model.
        detection_model (Optional[str]): The name of the object detection model.
        detection_confidence (Optional[float]): The confidence threshold for object detection.
        kbvqa_model (Optional[AutoModelForCausalLM]): The fine-tuned language model for KBVQA.
        bnb_config (Optional[BitsAndBytesConfig]): BitsAndBytes quantization config, or None when the quantization
            setting is neither '4bit' nor '8bit'.
        access_token (str): Access token for Hugging Face API.
        current_prompt_length (Optional[int]): Token length of the most recently built prompt.

    Methods:
        create_bnb_config: Creates a BitsAndBytes configuration based on the quantization setting.
        load_caption_model: Loads the image captioning model.
        get_caption: Generates a caption for a given image.
        load_detector: Loads the object detection model.
        detect_objects: Detects objects in a given image.
        load_fine_tuned_model: Loads the fine-tuned KBVQA model along with its tokenizer.
        all_models_loaded: (property) Checks if all the required models are loaded.
        format_prompt: Formats the prompt for the KBVQA model.
        trim_objects: (static) Removes the last object from a detected-objects string.
        generate_answer: Generates an answer to a given question using the KBVQA model.
    """

    def __init__(self) -> None:
        """
        Initializes the KBVQA instance with configuration parameters.

        The fine-tuned model is selected from st.session_state["method"]. If the method is not a supported
        fine-tuned model, kbvqa_model_name stays None and load_fine_tuned_model() raises a ValueError.
        """
        method = st.session_state["method"]
        self.kbvqa_model_name: Optional[str] = None
        if method == "7b-Fine-Tuned Model":
            self.kbvqa_model_name = config.KBVQA_MODEL_NAME_7b
        elif method == "13b-Fine-Tuned Model":
            self.kbvqa_model_name = config.KBVQA_MODEL_NAME_13b
        self.quantization: str = config.QUANTIZATION
        self.max_context_window: int = config.MAX_CONTEXT_WINDOW  # set to 4,000 tokens
        self.add_eos_token: bool = config.ADD_EOS_TOKEN
        self.trust_remote: bool = config.TRUST_REMOTE
        self.use_fast: bool = config.USE_FAST
        self.low_cpu_mem_usage: bool = config.LOW_CPU_MEM_USAGE
        self.kbvqa_tokenizer: Optional[AutoTokenizer] = None
        self.captioner: Optional[ImageCaptioningModel] = None
        self.detector: Optional[ObjectDetector] = None
        self.detection_model: Optional[str] = None
        self.detection_confidence: Optional[float] = None
        self.kbvqa_model: Optional[AutoModelForCausalLM] = None
        self.bnb_config: Optional[BitsAndBytesConfig] = self.create_bnb_config()
        self.access_token: str = config.HUGGINGFACE_TOKEN
        self.current_prompt_length: Optional[int] = None

    def create_bnb_config(self) -> Optional[BitsAndBytesConfig]:
        """
        Creates a BitsAndBytes configuration based on the quantization setting.

        Returns:
            Optional[BitsAndBytesConfig]: A 4-bit NF4 (double-quantized, bfloat16 compute) config for '4bit',
            an 8-bit config for '8bit', or None for any other setting (no bitsandbytes quantization).
        """
        if self.quantization == '4bit':
            return BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_compute_dtype=torch.bfloat16,
            )
        if self.quantization == '8bit':
            # 8-bit loading has no double-quant / quant-type / compute-dtype options. The previously passed
            # bnb_8bit_* keywords are not BitsAndBytesConfig parameters and were ignored.
            return BitsAndBytesConfig(load_in_8bit=True)
        return None

    def load_caption_model(self) -> None:
        """
        Loads the image captioning model into the KBVQA instance.

        Returns:
            None
        """
        self.captioner = ImageCaptioningModel()
        self.captioner.load_model()
        free_gpu_resources()

    def get_caption(self, img: Image.Image) -> str:
        """
        Generates a caption for a given image using the image captioning model.

        Args:
            img (PIL.Image.Image): The image for which to generate a caption.

        Returns:
            str: The generated caption for the image.
        """
        caption = self.captioner.generate_caption(img)
        free_gpu_resources()
        return caption

    def load_detector(self, model: str) -> None:
        """
        Loads the object detection model.

        Args:
            model (str): The name of the object detection model to load.

        Returns:
            None
        """
        self.detector = ObjectDetector()
        self.detector.load_model(model)
        free_gpu_resources()

    def detect_objects(self, img: Image.Image) -> Tuple[Image.Image, str]:
        """
        Detects objects in a given image using the loaded object detection model.

        The detection threshold is read from st.session_state['confidence_level'].

        Args:
            img (PIL.Image.Image): The image in which to detect objects.

        Returns:
            tuple: The image with detected objects drawn, and a string representation of the detected objects.
        """
        image = self.detector.process_image(img)
        free_gpu_resources()
        detected_objects_string, detected_objects_list = self.detector.detect_objects(
            image, threshold=st.session_state['confidence_level'])
        free_gpu_resources()
        # Boxes are drawn on the original image, not on the processed one.
        image_with_boxes = self.detector.draw_boxes(img, detected_objects_list)
        free_gpu_resources()
        return image_with_boxes, detected_objects_string

    def load_fine_tuned_model(self) -> None:
        """
        Loads the fine-tuned KBVQA model along with its tokenizer.

        Returns:
            None

        Raises:
            ValueError: If no fine-tuned model name was selected in __init__ (unsupported method).
        """
        if self.kbvqa_model_name is None:
            raise ValueError(
                "No fine-tuned KB-VQA model name is set; st.session_state['method'] must be "
                "'7b-Fine-Tuned Model' or '13b-Fine-Tuned Model'."
            )
        self.kbvqa_model = AutoModelForCausalLM.from_pretrained(
            self.kbvqa_model_name,
            device_map="auto",
            # Kept True regardless of self.low_cpu_mem_usage: transformers rejects device_map with
            # low_cpu_mem_usage=False.
            low_cpu_mem_usage=True,
            quantization_config=self.bnb_config,
            token=self.access_token,
        )
        free_gpu_resources()
        self.kbvqa_tokenizer = AutoTokenizer.from_pretrained(
            self.kbvqa_model_name,
            use_fast=self.use_fast,
            trust_remote_code=self.trust_remote,
            add_eos_token=self.add_eos_token,
            token=self.access_token,
        )
        free_gpu_resources()

    @property
    def all_models_loaded(self) -> bool:
        """
        Checks if all the required models (KBVQA, captioner, detector) are loaded.

        Returns:
            bool: True if all models are loaded, False otherwise.
        """
        return self.kbvqa_model is not None and self.captioner is not None and self.detector is not None

    def format_prompt(self, current_query: str, history: Optional[str] = None, sys_prompt: Optional[str] = None,
                      caption: Optional[str] = None, objects: Optional[str] = None) -> str:
        """
        Formats the prompt for the KBVQA model based on the provided parameters.

        This implements the Prompt Engineering Module of the overall KB-VQA architecture.

        Args:
            current_query (str): The current question to be answered.
            history (str, optional): The history of previous interactions. When given, only the new question is
                appended to it, without system prompt, caption or objects.
            sys_prompt (str, optional): The system prompt; defaults to config.SYSTEM_PROMPT.
            caption (str, optional): The caption of the image.
            objects (str, optional): The detected objects in the image. When None, the object section is omitted.

        Returns:
            str: The formatted prompt for the KBVQA model.
        """
        # These are the special tokens designed for the model to be fine-tuned on.
        B_CAP = '[CAP]'
        E_CAP = '[/CAP]'
        B_QES = '[QES]'
        E_QES = '[/QES]'
        B_OBJ = '[OBJ]'
        E_OBJ = '[/OBJ]'
        # These are the default special tokens of LLaMA-2 Chat Model.
        B_SENT = '<s>'
        E_SENT = '</s>'
        B_INST = '[INST]'
        E_INST = '[/INST]'
        B_SYS = '<<SYS>>\n'
        E_SYS = '\n<</SYS>>\n\n'
        current_query = current_query.strip()
        if sys_prompt is None:
            sys_prompt = config.SYSTEM_PROMPT.strip()
        # History can be used to facilitate multi turn chat, not used for the Run Inference tool within the demo app.
        if history is None:
            if objects is None:
                p = f"""{B_SENT}{B_INST} {B_SYS}{sys_prompt}{E_SYS}{B_CAP}{caption}{E_CAP}{B_QES}{current_query}{E_QES}{E_INST}"""
            else:
                p = f"""{B_SENT}{B_INST} {B_SYS}{sys_prompt}{E_SYS}{B_CAP}{caption}{E_CAP}{B_OBJ}{objects}{E_OBJ}{B_QES}taking into consideration the objects with high certainty, {current_query}{E_QES}{E_INST}"""
        else:
            p = f"""{history}\n{B_SENT}{B_INST} {B_QES}{current_query}{E_QES}{E_INST}"""
        return p

    @staticmethod
    def trim_objects(detected_objects_str: Optional[str]) -> str:
        """
        Trim the last object from the detected objects string.

        This keeps the prompt within the context window (threshold set to 4,000 tokens).
        Objects are expected one per line.

        Args:
            detected_objects_str (str): String containing detected objects, one per line. Empty or None is
                treated as "no objects".

        Returns:
            str: The string with the last object removed ("" when one or no objects remain).
        """
        if not detected_objects_str:
            return ""
        objects = detected_objects_str.strip().split("\n")
        # str.split always returns at least one element, so dropping the last one is always valid.
        return "\n".join(objects[:-1])

    def generate_answer(self, question: str, caption: str, detected_objects_str: Optional[str]) -> str:
        """
        Generates an answer to a given question using the KBVQA model.

        If the prompt exceeds max_context_window tokens, detected objects are removed from the end of the list,
        one at a time, until the prompt fits or no objects remain.

        Args:
            question (str): The question to be answered.
            caption (str): The caption of the image related to the question.
            detected_objects_str (str, optional): The string representation of detected objects in the image
                (one per line), or None to omit the object section from the prompt.

        Returns:
            str: The generated answer to the question (prompt tokens excluded), capitalized.
        """
        free_gpu_resources()
        prompt = self.format_prompt(question, caption=caption, objects=detected_objects_str)
        self.current_prompt_length = len(self.kbvqa_tokenizer.tokenize(prompt))
        # max_context_window is set to 4,000 tokens, refer to the config file.
        trim = self.current_prompt_length > self.max_context_window
        if trim:
            st.warning(
                f"Prompt length is {self.current_prompt_length} which is larger than the maximum context window of LLaMA-2,"
                f" objects detected with low confidence will be removed one at a time until the prompt length is within the"
                f" maximum context window ...")
        # An object is trimmed from the bottom of the list until the prompt fits or nothing is left to trim.
        while self.current_prompt_length > self.max_context_window and detected_objects_str:
            detected_objects_str = self.trim_objects(detected_objects_str)
            prompt = self.format_prompt(question, caption=caption, objects=detected_objects_str)
            self.current_prompt_length = len(self.kbvqa_tokenizer.tokenize(prompt))
        if trim:
            st.warning(f"New prompt length is: {self.current_prompt_length}")

        # The prompt already contains <s> explicitly, hence add_special_tokens=False.
        # Inputs go to the device holding the model's first parameters (the embeddings under device_map="auto").
        model_inputs = self.kbvqa_tokenizer(prompt, add_special_tokens=False, return_tensors="pt").to(
            self.kbvqa_model.device)
        free_gpu_resources()
        input_ids = model_inputs["input_ids"]
        output_ids = self.kbvqa_model.generate(input_ids, attention_mask=model_inputs.get("attention_mask"))
        free_gpu_resources()
        prompt_token_count = input_ids.shape[1]  # skip the echoed prompt tokens in the output
        output_text = self.kbvqa_tokenizer.decode(output_ids[0][prompt_token_count:], skip_special_tokens=True)
        return output_text.capitalize()
def prepare_kbvqa_model(only_reload_detection_model: bool = False, force_reload: bool = False) -> KBVQA:
    """
    Prepares the KBVQA model for use, including loading necessary sub-models.

    This serves as the main function for loading and reloading the KB-VQA model. A new KBVQA instance is always
    created; the detection model name is taken from st.session_state.detection_model.

    Args:
        only_reload_detection_model (bool): If True, only the object detection model is loaded. The returned
            instance then has no captioner or language model, so all_models_loaded is False.
            NOTE(review): presumably the caller merges this detector into its cached instance — confirm.
        force_reload (bool): If True, removes the cached st.session_state['kbvqa'] (if present) before loading.

    Returns:
        KBVQA: An instance of the KBVQA model ready for inference.
    """
    if force_reload:
        loading_message = 'Reloading model.. this should take no more than 2 or 3 minutes!'
        # Drop the cached instance so its models can be released before loading new ones.
        try:
            del st.session_state['kbvqa']
        except KeyError:
            pass  # Nothing cached yet, so there is nothing to release.
    else:
        loading_message = 'Loading model.. this should take no more than 2 or 3 minutes!'
    free_gpu_resources()

    kbvqa = KBVQA()
    kbvqa.detection_model = st.session_state.detection_model
    # Progress bar for model loading
    with st.spinner(loading_message):
        progress_bar = st.progress(0)
        if only_reload_detection_model:
            kbvqa.load_detector(kbvqa.detection_model)
            progress_bar.progress(100)
        else:
            kbvqa.load_detector(kbvqa.detection_model)
            progress_bar.progress(33)
            kbvqa.load_caption_model()
            free_gpu_resources()
            progress_bar.progress(75)
            st.text('Almost there :)')
            kbvqa.load_fine_tuned_model()
            free_gpu_resources()
            progress_bar.progress(100)

    if kbvqa.all_models_loaded:
        st.success('Model loaded successfully and ready for inference!')
        # Switch the language model to inference mode (disables dropout).
        kbvqa.kbvqa_model.eval()
        free_gpu_resources()
    return kbvqa
if __name__ == "__main__":
    # Running this file directly does nothing; the module is presumably imported by the Streamlit app.
    pass
    #### Example on how to use the module ####
    # NOTE: KBVQA reads st.session_state["method"], st.session_state.detection_model and
    # st.session_state['confidence_level'], so this example only works inside a Streamlit session
    # where those keys have been set.
    # Prepare the KBVQA model
    # kbvqa = prepare_kbvqa_model()
    # Load an image
    # image = Image.open('path_to_image.jpg')
    # Generate a caption for the image
    # caption = kbvqa.get_caption(image)
    # Detect objects in the image
    # image_with_boxes, detected_objects_str = kbvqa.detect_objects(image)
    # Generate an answer to a question about the image
    # question = "What is the object in the image?"
    # answer = kbvqa.generate_answer(question, caption, detected_objects_str)
    # print(f"Answer: {answer}")