|
import streamlit as st |
|
import torch |
|
import copy |
|
import os |
|
from PIL import Image |
|
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig |
|
from typing import Tuple, Optional |
|
from my_model.gen_utilities import free_gpu_resources |
|
from my_model.captioner.image_captioning import ImageCaptioningModel |
|
from my_model.object_detection import ObjectDetector |
|
import my_model.config.kbvqa_config as config |
|
|
|
|
|
class KBVQA:
    """
    The KBVQA class encapsulates the functionality for the Knowledge-Based Visual Question Answering (KBVQA) model.

    It integrates an image captioning model, an object detection model, and a language model (LLAMA2)
    fine-tuned on the OK-VQA dataset to generate answers to visual questions.

    Attributes:
        kbvqa_model_name (str): Name of the fine-tuned language model used for KBVQA.
        quantization (str): The quantization setting for the model. '4bit' and '8bit' enable BitsAndBytes
            quantization; any other value leaves the model unquantized.
        max_context_window (int): The maximum number of prompt tokens accepted by `generate_answer`.
        add_eos_token (bool): Passed as `add_eos_token` to the tokenizer (whether to append an
            end-of-sequence token).
        trust_remote (bool): Whether to trust remote code when loading the tokenizer.
        use_fast (bool): Whether to use the fast version of the tokenizer.
        low_cpu_mem_usage (bool): Low-CPU-memory loading flag read from config (see `load_fine_tuned_model`
            for why the model is always loaded with `low_cpu_mem_usage=True`).
        kbvqa_tokenizer (Optional[AutoTokenizer]): The tokenizer for the KBVQA model.
        captioner (Optional[ImageCaptioningModel]): The model used for generating image captions.
        detector (Optional[ObjectDetector]): The object detection model.
        detection_model (Optional[str]): The name of the object detection model.
        detection_confidence (Optional[float]): The confidence threshold passed to the object detector.
        kbvqa_model (Optional[AutoModelForCausalLM]): The fine-tuned language model for KBVQA.
        bnb_config (Optional[BitsAndBytesConfig]): BitsAndBytes quantization config, or None when
            `quantization` is neither '4bit' nor '8bit'.
        access_token (str): Access token for the Hugging Face Hub.

    Methods:
        create_bnb_config: Creates a BitsAndBytes configuration based on the quantization setting.
        load_caption_model: Loads the image captioning model.
        get_caption: Generates a caption for a given image.
        load_detector: Loads the object detection model.
        detect_objects: Detects objects in a given image.
        load_fine_tuned_model: Loads the fine-tuned KBVQA model along with its tokenizer.
        all_models_loaded: Property; checks whether the language model, captioner and detector are loaded.
        force_reload_model: Unloads all models and frees GPU resources so they can be loaded again.
        format_prompt: Formats the prompt for the KBVQA model.
        generate_answer: Generates an answer to a given question using the KBVQA model.
    """

    def __init__(self):
        # Settings read from the project configuration.
        self.kbvqa_model_name = config.KBVQA_MODEL_NAME
        self.quantization = config.QUANTIZATION
        self.max_context_window = config.MAX_CONTEXT_WINDOW
        self.add_eos_token = config.ADD_EOS_TOKEN
        self.trust_remote = config.TRUST_REMOTE
        self.use_fast = config.USE_FAST
        self.low_cpu_mem_usage = config.LOW_CPU_MEM_USAGE

        # Sub-models and runtime settings; populated by the load_* methods or set externally.
        self.kbvqa_tokenizer = None
        self.captioner = None
        self.detector = None
        self.detection_model = None
        self.detection_confidence = None
        self.kbvqa_model = None

        # Must run after self.quantization is assigned, because create_bnb_config reads it.
        self.bnb_config = self.create_bnb_config()
        self.access_token = config.HUGGINGFACE_TOKEN

    def create_bnb_config(self) -> Optional[BitsAndBytesConfig]:
        """
        Creates a BitsAndBytes configuration based on the quantization setting.

        Returns:
            Optional[BitsAndBytesConfig]: For '4bit', a double-quantized NF4 config with bfloat16 compute;
            for '8bit', an 8-bit config; for any other value, None (the model is loaded unquantized).
        """
        if self.quantization == '4bit':
            return BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_compute_dtype=torch.bfloat16,
            )
        if self.quantization == '8bit':
            # Only `load_in_8bit` applies here: BitsAndBytesConfig has no `bnb_8bit_*` options
            # (such kwargs are ignored with an "Unused kwargs" warning), and NF4 is a 4-bit-only type.
            return BitsAndBytesConfig(load_in_8bit=True)
        return None

    def load_caption_model(self) -> None:
        """
        Loads the image captioning model into the KBVQA instance.
        """
        self.captioner = ImageCaptioningModel()
        self.captioner.load_model()

    def get_caption(self, img: Image.Image) -> str:
        """
        Generates a caption for a given image using the image captioning model.

        `load_caption_model` must have been called first.

        Args:
            img (PIL.Image.Image): The image for which to generate a caption.

        Returns:
            str: The generated caption for the image.
        """
        return self.captioner.generate_caption(img)

    def load_detector(self, model: str) -> None:
        """
        Loads the object detection model.

        Args:
            model (str): The name of the object detection model to load.
        """
        self.detector = ObjectDetector()
        self.detector.load_model(model)

    def detect_objects(self, img: Image.Image) -> Tuple[Image.Image, str]:
        """
        Detects objects in a given image using the loaded object detection model.

        The detector's threshold is `self.detection_confidence` (None unless set by the caller).

        Args:
            img (PIL.Image.Image): The image in which to detect objects.

        Returns:
            tuple: The original image with the detected boxes drawn on it, and a string
            representation of the detected objects.
        """
        processed = self.detector.process_image(img)
        detected_objects_string, detected_objects_list = self.detector.detect_objects(
            processed, threshold=self.detection_confidence
        )
        # Boxes are drawn on the original image, not on the processed one.
        image_with_boxes = self.detector.draw_boxes(img, detected_objects_list)
        return image_with_boxes, detected_objects_string

    def load_fine_tuned_model(self) -> None:
        """
        Loads the fine-tuned KBVQA model along with its tokenizer from the Hugging Face Hub.
        """
        # NOTE(review): transformers rejects low_cpu_mem_usage=False when a device_map is given, which is
        # presumably why True is passed here instead of self.low_cpu_mem_usage — confirm for your version.
        self.kbvqa_model = AutoModelForCausalLM.from_pretrained(
            self.kbvqa_model_name,
            device_map="auto",
            low_cpu_mem_usage=True,
            quantization_config=self.bnb_config,
            token=self.access_token,
        )

        # `low_cpu_mem_usage` is a model-loading option, so it is not passed to the tokenizer.
        self.kbvqa_tokenizer = AutoTokenizer.from_pretrained(
            self.kbvqa_model_name,
            use_fast=self.use_fast,
            trust_remote_code=self.trust_remote,
            add_eos_token=self.add_eos_token,
            token=self.access_token,
        )

    @property
    def all_models_loaded(self) -> bool:
        """
        Checks if all the required models (KBVQA, captioner, detector) are loaded.

        Returns:
            bool: True if all models are loaded, False otherwise.
        """
        return self.kbvqa_model is not None and self.captioner is not None and self.detector is not None

    def force_reload_model(self) -> None:
        """
        Unloads all models so they can be loaded again, freeing up GPU resources.

        The references are reset to None rather than deleted with `del`, so the attributes still exist
        afterwards and `all_models_loaded` keeps working (it reports False). `free_gpu_resources` is called
        once the references have been dropped. Use the `load_*` methods to load the models again.
        """
        self.kbvqa_model = None
        self.kbvqa_tokenizer = None
        self.captioner = None
        self.detector = None
        free_gpu_resources()

    def format_prompt(self, current_query: str, history: Optional[str] = None, sys_prompt: Optional[str] = None,
                      caption: Optional[str] = None, objects: Optional[str] = None) -> str:
        """
        Formats the prompt for the KBVQA model based on the provided parameters.

        Without `history`, the prompt contains the system prompt, the caption and (if given) the detected
        objects. With `history`, the new question is appended to it and `sys_prompt`, `caption` and
        `objects` are ignored.

        Args:
            current_query (str): The current question to be answered (surrounding whitespace is stripped).
            history (str, optional): The history of previous interactions.
            sys_prompt (str, optional): The system prompt; defaults to the stripped `config.SYSTEM_PROMPT`.
            caption (str, optional): The caption of the image (rendered literally, so None becomes "None").
            objects (str, optional): The detected objects in the image.

        Returns:
            str: The formatted prompt for the KBVQA model.
        """
        B_SENT = '<s>'
        E_SENT = '</s>'
        B_INST = '[INST]'
        E_INST = '[/INST]'
        B_SYS = '<<SYS>>\n'
        E_SYS = '\n<</SYS>>\n\n'
        B_CAP = '[CAP]'
        E_CAP = '[/CAP]'
        B_QES = '[QES]'
        E_QES = '[/QES]'
        B_OBJ = '[OBJ]'
        E_OBJ = '[/OBJ]'

        current_query = current_query.strip()
        if sys_prompt is None:
            sys_prompt = config.SYSTEM_PROMPT.strip()

        if history is not None:
            # Follow-up turn: append only the new question to the existing conversation.
            return f"""{history}\n{B_SENT}{B_INST} {B_QES}{current_query}{E_QES}{E_INST}"""

        if objects is None:
            return f"""{B_SENT}{B_INST} {B_SYS}{sys_prompt}{E_SYS}{B_CAP}{caption}{E_CAP}{B_QES}{current_query}{E_QES}{E_INST}"""

        return f"""{B_SENT}{B_INST} {B_SYS}{sys_prompt}{E_SYS}{B_CAP}{caption}{E_CAP}{B_OBJ}{objects}{E_OBJ}{B_QES}taking into consideration the objects with high certainty, {current_query}{E_QES}{E_INST}"""

    def generate_answer(self, question: str, caption: str, detected_objects_str: str) -> Optional[str]:
        """
        Generates an answer to a given question using the KBVQA model.

        Args:
            question (str): The question to be answered.
            caption (str): The caption of the image related to the question.
            detected_objects_str (str): The string representation of detected objects in the image.

        Returns:
            Optional[str]: The generated answer (passed through `str.capitalize`), or None when the prompt
            exceeds `max_context_window` tokens (a message is then shown in the Streamlit UI).
        """
        prompt = self.format_prompt(question, caption=caption, objects=detected_objects_str)

        # Tokenize once and reuse the encoding for both the length check and generation. The prompt
        # already contains its own special tokens, so the tokenizer must not add more.
        model_inputs = self.kbvqa_tokenizer(prompt, add_special_tokens=False, return_tensors="pt")
        num_tokens = model_inputs["input_ids"].shape[1]
        if num_tokens > self.max_context_window:
            st.write(f"Prompt too long with {num_tokens} tokens, consider increasing the confidence threshold for the object detector")
            return None

        # Move inputs to wherever the model lives instead of hard-coding 'cuda'.
        model_inputs = model_inputs.to(self.kbvqa_model.device)
        output_ids = self.kbvqa_model.generate(
            input_ids=model_inputs["input_ids"],
            attention_mask=model_inputs.get("attention_mask"),
        )

        # generate() returns prompt + continuation; keep only the newly generated tokens.
        output_text = self.kbvqa_tokenizer.decode(output_ids[0][num_tokens:], skip_special_tokens=True)
        return output_text.capitalize()
|
|
|
def prepare_kbvqa_model(only_reload_detection_model: bool = False) -> KBVQA:
    """
    Prepares the KBVQA model for use, including loading necessary sub-models.

    A new KBVQA instance is always created. Its detection model name is read from
    `st.session_state.detection_model`, and loading progress is shown in the Streamlit UI.

    Args:
        only_reload_detection_model (bool): If True, only the object detection model is loaded on the
            new instance; the captioner and the fine-tuned language model stay unloaded (None).

    Returns:
        KBVQA: The new instance. After a full load it is ready for inference. When
        `only_reload_detection_model` is True, only its `detector` is populated.
        NOTE(review): presumably the caller moves that detector onto an existing instance — confirm.
    """
    free_gpu_resources()
    kbvqa = KBVQA()
    kbvqa.detection_model = st.session_state.detection_model

    with st.spinner('Loading model...'):
        if not only_reload_detection_model:
            st.text('this should take no more than a few minutes!')
            progress_bar = st.progress(0)
            kbvqa.load_detector(kbvqa.detection_model)
            progress_bar.progress(33)
            kbvqa.load_caption_model()
            free_gpu_resources()
            progress_bar.progress(75)
            st.text('Almost there :)')
            kbvqa.load_fine_tuned_model()
            free_gpu_resources()
            progress_bar.progress(100)
        else:
            progress_bar = st.progress(0)
            kbvqa.load_detector(kbvqa.detection_model)
            progress_bar.progress(100)

    # Only true after the full load: the detector-only path leaves kbvqa_model and captioner as None.
    if kbvqa.all_models_loaded:
        st.success('Model loaded successfully and ready for inference!')
        # Put the language model into inference mode.
        kbvqa.kbvqa_model.eval()

    free_gpu_resources()
    return kbvqa
|
|
|
|
|
|