File size: 12,559 Bytes
c59fc6b
 
9347b1e
139bf60
0b430a0
c59fc6b
61e10b7
2997bb2
d26dd8d
c59fc6b
61e10b7
c59fc6b
 
 
61e10b7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c59fc6b
 
61e10b7
 
 
 
 
 
 
 
c59fc6b
 
 
518eb6e
824d7ec
c59fc6b
61e10b7
 
 
c59fc6b
 
 
08ec8d2
 
 
 
 
c59fc6b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61e10b7
 
 
 
 
567d6d1
c59fc6b
 
61e10b7
 
 
 
 
 
 
 
 
 
c59fc6b
 
 
61e10b7
 
 
 
 
 
 
c59fc6b
 
 
 
61e10b7
 
 
 
 
 
 
 
 
 
 
c59fc6b
6f1c42e
c59fc6b
 
 
61e10b7
 
 
 
 
ec4889b
 
 
 
 
 
 
 
 
 
 
 
c59fc6b
 
 
 
61e10b7
 
 
 
 
 
 
c59fc6b
 
e57843e
61e10b7
 
 
 
e57843e
 
 
 
 
 
 
 
 
 
 
61e10b7
 
 
c59fc6b
61e10b7
 
 
 
 
 
c59fc6b
61e10b7
 
 
c59fc6b
 
 
 
 
 
 
 
 
 
 
 
 
 
9335258
 
c59fc6b
 
 
 
 
 
 
 
 
 
 
61e10b7
 
 
 
 
 
 
 
 
 
 
 
91f466a
c59fc6b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61e10b7
 
 
 
 
 
 
 
 
 
 
97bc44b
c59fc6b
c690614
c59fc6b
8024886
e57843e
7ea3839
8024886
7ea3839
 
 
 
 
df43e4c
8838f09
7ea3839
 
 
 
 
 
 
61e10b7
c59fc6b
c62c890
c59fc6b
97bc44b
c59fc6b
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
import streamlit as st
import torch
import copy
import os
from PIL import Image
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from typing import Tuple, Optional
from my_model.gen_utilities import free_gpu_resources
from my_model.captioner.image_captioning import ImageCaptioningModel
from my_model.object_detection import ObjectDetector
import my_model.config.kbvqa_config as config


class KBVQA():
    """
    The KBVQA class encapsulates the functionality for the Knowledge-Based Visual Question Answering (KBVQA) model.
    It integrates various components such as an image captioning model, object detection model, and a fine-tuned
    language model (LLAMA2) on OK-VQA dataset for generating answers to visual questions.

    Attributes:
        kbvqa_model_name (str): Name of the fine-tuned language model used for KBVQA.
        quantization (str): The quantization setting for the model ('4bit', '8bit'; anything else means no
            quantization config).
        max_context_window (int): The maximum number of prompt tokens accepted by `generate_answer`.
        add_eos_token (bool): Flag passed to the tokenizer to append an end-of-sequence token.
        trust_remote (bool): Flag to indicate whether to trust remote code when loading the tokenizer.
        use_fast (bool): Flag to indicate whether to use the fast version of the tokenizer.
        low_cpu_mem_usage (bool): Configured low-CPU-memory flag (see `load_fine_tuned_model` for why the model
            load always uses True).
        kbvqa_tokenizer (Optional[AutoTokenizer]): The tokenizer for the KBVQA model.
        captioner (Optional[ImageCaptioningModel]): The model used for generating image captions.
        detector (Optional[ObjectDetector]): The object detection model.
        detection_model (Optional[str]): The name of the object detection model.
        detection_confidence (Optional[float]): The confidence threshold for object detection (None until set
            by the caller).
        kbvqa_model (Optional[AutoModelForCausalLM]): The fine-tuned language model for KBVQA.
        bnb_config (Optional[BitsAndBytesConfig]): BitsAndBytes quantization config, or None when unquantised.
        access_token (str): Access token for Hugging Face API.

    Methods:
        create_bnb_config: Creates a BitsAndBytes configuration based on the quantization setting.
        load_caption_model: Loads the image captioning model.
        get_caption: Generates a caption for a given image.
        load_detector: Loads the object detection model.
        detect_objects: Detects objects in a given image.
        load_fine_tuned_model: Loads the fine-tuned KBVQA model along with its tokenizer.
        all_models_loaded: Checks if all the required models are loaded.
        force_reload_model: Releases all models and frees GPU resources so they can be reloaded.
        format_prompt: Formats the prompt for the KBVQA model.
        generate_answer: Generates an answer to a given question using the KBVQA model.
    """

    def __init__(self):
        """Reads settings from the project config; no model is loaded until a `load_*` method is called."""

        self.kbvqa_model_name = config.KBVQA_MODEL_NAME
        self.quantization = config.QUANTIZATION
        self.max_context_window = config.MAX_CONTEXT_WINDOW
        self.add_eos_token = config.ADD_EOS_TOKEN
        self.trust_remote = config.TRUST_REMOTE
        self.use_fast = config.USE_FAST
        self.low_cpu_mem_usage = config.LOW_CPU_MEM_USAGE
        self.kbvqa_tokenizer = None
        self.captioner = None
        self.detector = None
        self.detection_model = None
        self.detection_confidence = None
        self.kbvqa_model = None
        self.bnb_config = self.create_bnb_config()
        self.access_token = config.HUGGINGFACE_TOKEN

    def create_bnb_config(self) -> Optional[BitsAndBytesConfig]:
        """
        Creates a BitsAndBytes configuration based on the quantization setting.

        Returns:
            Optional[BitsAndBytesConfig]: For '4bit', an NF4 double-quantised config with bfloat16 compute;
            for '8bit', a plain 8-bit config; for any other value, None (the model is loaded unquantised).
        """
        if self.quantization == '4bit':
            return BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_compute_dtype=torch.bfloat16
            )
        if self.quantization == '8bit':
            # Double quantisation, NF4 and the compute dtype are 4-bit-only options (bnb_4bit_*). The previous
            # bnb_8bit_* keyword arguments are not BitsAndBytesConfig parameters and were silently ignored.
            return BitsAndBytesConfig(load_in_8bit=True)
        return None

    def load_caption_model(self) -> None:
        """
        Loads the image captioning model into the KBVQA instance.
        """

        self.captioner = ImageCaptioningModel()
        self.captioner.load_model()

    def get_caption(self, img: Image.Image) -> str:
        """
        Generates a caption for a given image using the image captioning model.
        `load_caption_model` must have been called first.

        Args:
            img (PIL.Image.Image): The image for which to generate a caption.

        Returns:
            str: The generated caption for the image.
        """

        return self.captioner.generate_caption(img)

    def load_detector(self, model: str) -> None:
        """
        Loads the object detection model.

        Args:
            model (str): The name of the object detection model to load.
        """

        self.detector = ObjectDetector()
        self.detector.load_model(model)

    def detect_objects(self, img: Image.Image) -> Tuple[Image.Image, str]:
        """
        Detects objects in a given image using the loaded object detection model.
        `load_detector` must have been called first; the threshold used is `self.detection_confidence`.

        Args:
            img (PIL.Image.Image): The image in which to detect objects.

        Returns:
            tuple: The original image with detected boxes drawn, and a string representation of detected objects.
        """

        image = self.detector.process_image(img)
        detected_objects_string, detected_objects_list = self.detector.detect_objects(image, threshold=self.detection_confidence)
        # Boxes are drawn on the caller's original image, not on the detector-preprocessed one.
        image_with_boxes = self.detector.draw_boxes(img, detected_objects_list)
        return image_with_boxes, detected_objects_string

    def load_fine_tuned_model(self) -> None:
        """
        Loads the fine-tuned KBVQA model along with its tokenizer.
        """

        # low_cpu_mem_usage stays True: from_pretrained rejects low_cpu_mem_usage=False together with a device_map.
        self.kbvqa_model = AutoModelForCausalLM.from_pretrained(self.kbvqa_model_name,
                                                                device_map="auto",
                                                                low_cpu_mem_usage=True,
                                                                quantization_config=self.bnb_config,
                                                                token=self.access_token)

        self.kbvqa_tokenizer = AutoTokenizer.from_pretrained(self.kbvqa_model_name,
                                                             use_fast=self.use_fast,
                                                             low_cpu_mem_usage=True,
                                                             trust_remote_code=self.trust_remote,
                                                             add_eos_token=self.add_eos_token,
                                                             token=self.access_token)

    @property
    def all_models_loaded(self) -> bool:
        """
        Checks if all the required models (KBVQA, captioner, detector) are loaded.

        Returns:
            bool: True if all models are loaded, False otherwise.
        """

        return self.kbvqa_model is not None and self.captioner is not None and self.detector is not None

    def force_reload_model(self) -> None:
        """
        Releases all loaded models and frees GPU resources so the models can be loaded again.

        The model attributes are reset to None rather than deleted. This keeps the instance consistent:
        `all_models_loaded`, repeated calls to this method and the `load_*` methods keep working afterwards.
        """

        free_gpu_resources()
        self.kbvqa_model = None
        self.kbvqa_tokenizer = None
        self.captioner = None
        self.detector = None
        free_gpu_resources()

    def format_prompt(self, current_query: str, history: Optional[str] = None, sys_prompt: Optional[str] = None, caption: Optional[str] = None, objects: Optional[str] = None) -> str:
        """
        Formats the prompt for the KBVQA model based on the provided parameters.

        Args:
            current_query (str): The current question to be answered (surrounding whitespace is stripped).
            history (str, optional): The history of previous interactions. When given, only the new question
                turn is appended to it, and the system prompt, caption and objects are not repeated.
            sys_prompt (str, optional): The system prompt; defaults to `config.SYSTEM_PROMPT`.
            caption (str, optional): The caption of the image.
            objects (str, optional): The detected objects in the image.

        Returns:
            str: The formatted prompt for the KBVQA model.
        """

        B_SENT = '<s>'
        B_INST = '[INST]'
        E_INST = '[/INST]'
        B_SYS = '<<SYS>>\n'
        E_SYS = '\n<</SYS>>\n\n'
        B_CAP = '[CAP]'
        E_CAP = '[/CAP]'
        B_QES = '[QES]'
        E_QES = '[/QES]'
        B_OBJ = '[OBJ]'
        E_OBJ = '[/OBJ]'
        current_query = current_query.strip()
        if sys_prompt is None:
            sys_prompt = config.SYSTEM_PROMPT.strip()
        if history is None:
            if objects is None:
                p = f"""{B_SENT}{B_INST} {B_SYS}{sys_prompt}{E_SYS}{B_CAP}{caption}{E_CAP}{B_QES}{current_query}{E_QES}{E_INST}"""
            else:
                p = f"""{B_SENT}{B_INST} {B_SYS}{sys_prompt}{E_SYS}{B_CAP}{caption}{E_CAP}{B_OBJ}{objects}{E_OBJ}{B_QES}taking into consideration the objects with high certainty, {current_query}{E_QES}{E_INST}"""
        else:
            p = f"""{history}\n{B_SENT}{B_INST} {B_QES}{current_query}{E_QES}{E_INST}"""

        return p

    def generate_answer(self, question: str, caption: str, detected_objects_str: str) -> Optional[str]:
        """
        Generates an answer to a given question using the KBVQA model.

        Args:
            question (str): The question to be answered.
            caption (str): The caption of the image related to the question.
            detected_objects_str (str): The string representation of detected objects in the image.

        Returns:
            Optional[str]: The generated answer, passed through `str.capitalize`. Returns None after writing
            a notice to the Streamlit page if the prompt exceeds `max_context_window` tokens.
        """

        prompt = self.format_prompt(question, caption=caption, objects=detected_objects_str)
        # Tokenize once. The prompt already contains its special tokens, so none are added here,
        # and the same encoding gives both the length check and the model input.
        model_inputs = self.kbvqa_tokenizer(prompt, add_special_tokens=False, return_tensors="pt")
        num_tokens = model_inputs["input_ids"].shape[1]
        if num_tokens > self.max_context_window:
            st.write(f"Prompt too long with {num_tokens} tokens, consider increasing the confidence threshold for the object detector")
            return None

        # Follow wherever device_map placed the model instead of assuming 'cuda'.
        model_inputs = model_inputs.to(self.kbvqa_model.device)
        input_ids = model_inputs["input_ids"]
        output_ids = self.kbvqa_model.generate(input_ids=input_ids,
                                               attention_mask=model_inputs.get("attention_mask"))
        prompt_length = input_ids.shape[1]  # slice off the echoed prompt so only the answer is decoded
        output_text = self.kbvqa_tokenizer.decode(output_ids[0][prompt_length:], skip_special_tokens=True)

        return output_text.capitalize()

def prepare_kbvqa_model(only_reload_detection_model: bool = False, existing_model: Optional[KBVQA] = None) -> Optional[KBVQA]:
    """
    Prepares the KBVQA model for use, including loading necessary sub-models.

    Args:
        only_reload_detection_model (bool): If True, only the object detection model is (re)loaded.
        existing_model (Optional[KBVQA]): An already-prepared instance whose detector is swapped in place when
            `only_reload_detection_model` is True. Without it, the detector-only path builds a fresh instance
            that has no captioner or language model. That instance cannot pass the readiness check, so None
            is returned. Ignored when a full load is requested.

    Returns:
        Optional[KBVQA]: An instance of the KBVQA model ready for inference, or None if not all of its
        sub-models are loaded.
    """

    free_gpu_resources()
    if only_reload_detection_model and existing_model is not None:
        kbvqa = existing_model
        # Release the old detector before loading the new one so both are not resident at once.
        kbvqa.detector = None
        free_gpu_resources()
    else:
        kbvqa = KBVQA()
    kbvqa.detection_model = st.session_state.detection_model
    # Progress bar for model loading
    with st.spinner('Loading model...'):

        if not only_reload_detection_model:
            st.text('this should take no more than a few minutes!')
            progress_bar = st.progress(0)
            kbvqa.load_detector(kbvqa.detection_model)
            progress_bar.progress(33)
            kbvqa.load_caption_model()
            free_gpu_resources()
            progress_bar.progress(75)
            st.text('Almost there :)')
            kbvqa.load_fine_tuned_model()
            free_gpu_resources()
            progress_bar.progress(100)
        else:
            progress_bar = st.progress(0)
            kbvqa.load_detector(kbvqa.detection_model)
            progress_bar.progress(100)

    if kbvqa.all_models_loaded:
        st.success('Model loaded successfully and ready for inference!')
        kbvqa.kbvqa_model.eval()
        free_gpu_resources()
        return kbvqa
    return None