File size: 15,962 Bytes
18d1852
e46d486
69b926c
0a62769
18d1852
49e9e5b
a812c7b
18d1852
11a17ef
6d4d5ac
18d1852
 
 
11a17ef
 
 
 
18d1852
 
 
 
 
12f08dc
 
4170a5f
74d450c
a264403
 
96dd295
 
 
 
69b926c
 
9becb2c
 
 
 
a264403
9becb2c
2152f1f
2957e90
1b71503
cf8c147
 
 
929a4bd
b40d522
 
11a17ef
0ed508e
18d1852
08655fb
 
11a17ef
fcf48a1
11a17ef
 
3fdd1d7
d80fd56
bfdde42
0ed508e
cf8c147
 
 
 
 
 
 
 
 
 
 
 
 
bfdde42
4e59653
bfdde42
4e59653
3fdd1d7
9becb2c
141a983
 
 
de78d1a
 
 
d80fd56
f4bcc28
 
cf8c147
 
 
 
 
 
f4bcc28
18d1852
afa0e9e
18d1852
cf8c147
 
 
 
11a17ef
69b926c
d6a4897
9becb2c
fa064be
18d1852
3fdd1d7
18d1852
cf8c147
 
 
 
3689b26
18d1852
e1f7f87
11c350b
3fdd1d7
18d1852
d0e9fe6
cf8c147
 
 
 
 
 
 
 
 
 
753c201
 
cc825df
d9364fd
f72214b
 
a264403
12f08dc
753c201
2b3b1de
753c201
f72214b
 
a812c7b
ffdb10e
96dd295
ffdb10e
a812c7b
ffdb10e
 
 
 
a812c7b
ffdb10e
96dd295
a812c7b
ffdb10e
96dd295
 
 
 
 
 
 
9becb2c
da52f83
 
 
 
 
 
 
ffdb10e
f72214b
 
cf8c147
 
 
 
 
 
f72214b
 
 
87da9a2
753c201
3fdd1d7
18d1852
cf8c147
 
 
 
 
18d1852
 
9becb2c
18d1852
cf8c147
 
 
 
 
 
c70ac46
18d1852
3fdd1d7
cd3678b
cf8c147
 
 
 
 
 
 
 
 
 
18d1852
 
9becb2c
0ed508e
1fc0405
9a8f19a
11a17ef
9becb2c
 
0ed508e
18d1852
 
 
 
d6f8382
18d1852
cf8c147
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18d1852
 
 
 
 
 
 
 
 
3fdd1d7
18d1852
cf8c147
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18d1852
d182243
18d1852
 
 
 
3fdd1d7
6a15fc4
cf8c147
 
 
 
 
 
 
 
18d1852
6a15fc4
18d1852
3fdd1d7
18d1852
cf8c147
 
 
 
 
 
18d1852
9becb2c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18d1852
afa0e9e
f993077
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2cd3e03
 
 
 
 
 
 
 
f993077
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ba99375
436f3b4
afa0e9e
9becb2c
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
import pandas as pd
import copy
import time
from PIL import Image
import streamlit as st
from my_model.utilities.gen_utilities import free_gpu_resources
from my_model.KBVQA import KBVQA, prepare_kbvqa_model




class StateManager:
    """
    Manages Streamlit session state, UI widgets and the KBVQA model lifecycle.

    Persistent values are stored in ``st.session_state`` so they survive Streamlit
    reruns; the instance itself only holds the three layout columns.
    """

    # Session-state keys shown in the "Current Model Settings" table.
    _DISPLAYED_SETTINGS = frozenset({
        "confidence_level", "detection_model", "method", "kbvqa", "previous_state",
        "settings_changed", "loading_in_progress", "model_loaded", "time_taken_to_load_model",
    })

    def __init__(self):
        """Create three page columns with relative widths 0.2 / 0.6 / 0.2."""
        self.col1, self.col2, self.col3 = st.columns([0.2, 0.6, 0.2])

    def initialize_state(self):
        """Populate any missing session-state keys with their default values."""
        defaults = {
            'images_data': {},
            'kbvqa': None,
            'button_label': "Load Model",
            'previous_state': {},
            'loading_in_progress': False,
            'load_button_clicked': False,
            'force_reload_button_clicked': False,
            'time_taken_to_load_model': 0,
        }
        for key, value in defaults.items():
            if key not in st.session_state:
                st.session_state[key] = value
        # These two are derived from 'previous_state' and 'kbvqa', so they must be
        # computed only after the defaults above are in place.
        if 'settings_changed' not in st.session_state:
            st.session_state['settings_changed'] = self.settings_changed
        if 'model_loaded' not in st.session_state:
            st.session_state['model_loaded'] = self.is_model_loaded

    def set_up_widgets(self):
        """
        Sets up user interface widgets for selecting models, settings, and displaying model settings conditionally.
        """
        self.col1.selectbox("Choose a method:", ["Fine-Tuned Model", "In-Context Learning (n-shots)"], index=0, key='method', disabled=self.is_widget_disabled)
        self.col1.selectbox("Choose a model for objects detection:", ["yolov5", "detic"], index=1, key='detection_model', disabled=self.is_widget_disabled)
        # The default confidence depends on the detector picked above.
        default_confidence = 0.2 if st.session_state.detection_model == "yolov5" else 0.4
        self.set_slider_value(text="Select minimum detection confidence level", min_value=0.1, max_value=0.9, value=default_confidence, step=0.1, slider_key_name='confidence_level', col=self.col1)

        # Conditional display of model settings
        show_model_settings = self.col3.checkbox("Show Model Settings", True, disabled=self.is_widget_disabled)
        if show_model_settings:
            self.display_model_settings()

    def set_slider_value(self, text, min_value, max_value, value, step, slider_key_name, col=None):
        """
        Creates a slider widget with the specified parameters, optionally placing it in a specific column.

        Args:
            text (str): Text to display next to the slider.
            min_value (float): Minimum value for the slider.
            max_value (float): Maximum value for the slider.
            value (float): Initial value for the slider.
            step (float): Step size for the slider.
            slider_key_name (str): Unique key for the slider.
            col (streamlit.columns.Column, optional): Column to place the slider in. Defaults to None (displayed in main area).

        Returns:
            The value returned by the slider widget.
        """
        # Both `st` and a column object expose the same `slider` API.
        container = st if col is None else col
        return container.slider(text, min_value, max_value, value, step, key=slider_key_name, disabled=self.is_widget_disabled)

    @property
    def is_widget_disabled(self):
        """bool: True while a model load is in progress (widgets are then disabled)."""
        return st.session_state['loading_in_progress']

    def disable_widgets(self):
        """Mark a model load as in progress, which disables the widgets on the next render."""
        st.session_state['loading_in_progress'] = True

    @property
    def settings_changed(self):
        """
        Checks if any model settings have changed compared to the previous state.

        Returns:
            bool: True if any setting has changed, False otherwise.
        """
        return self.has_state_changed()

    def display_model_settings(self):
        """Displays a table of current model settings in the third column."""
        self.col3.write("##### Current Model Settings:")
        data = [{'Setting': key, 'Value': str(value)} for key, value in st.session_state.items() if key in self._DISPLAYED_SETTINGS]
        df = pd.DataFrame(data)
        self.col3.write(df)

    def display_session_state(self):
        """Displays a table of the complete application state."""
        st.write("Current Model:")
        data = [{'Key': key, 'Value': str(value)} for key, value in st.session_state.items()]
        df = pd.DataFrame(data).reset_index(drop=True)
        st.write(df)

    @staticmethod
    def _current_settings():
        """Return a snapshot of the settings that `has_state_changed` later compares against."""
        return {
            'method': st.session_state.method,
            'detection_model': st.session_state.detection_model,
            'confidence_level': st.session_state.confidence_level,
        }

    def load_model(self):
        """
        Loads the KBVQA model based on the chosen method and settings.

        - Frees GPU resources before loading, and again afterwards (also on failure).
        - Calls `prepare_kbvqa_model` to create the model.
        - Sets the detection confidence level on the model object.
        - Updates previous state with current settings for change detection.
        - Updates the button label to "Reload Model".

        Errors are reported with `st.error` rather than raised.
        """
        try:
            free_gpu_resources()
            st.session_state['kbvqa'] = prepare_kbvqa_model()
            st.session_state['kbvqa'].detection_confidence = st.session_state.confidence_level
            # Record the settings the model was loaded with, for change detection.
            st.session_state['previous_state'] = self._current_settings()
            st.session_state['model_loaded'] = True
            st.session_state['button_label'] = "Reload Model"
        except Exception as e:
            st.error(f"Error loading model: {e}")
        finally:
            free_gpu_resources()

    def force_reload_model(self):
        """
        Deletes the current model and loads a fresh one via `prepare_kbvqa_model(force_reload=True)`.

        Applies the selected confidence level, records the current settings in
        'previous_state', and marks the model as loaded. Errors are reported with
        `st.error`; GPU resources are freed in all cases.
        """
        try:
            self.delete_model()
            free_gpu_resources()
            st.session_state['kbvqa'] = prepare_kbvqa_model(force_reload=True)
            st.session_state['kbvqa'].detection_confidence = st.session_state.confidence_level
            # Record the settings the model was loaded with, for change detection.
            st.session_state['previous_state'] = self._current_settings()
            st.session_state['model_loaded'] = True
            # Keep the button label consistent with the other successful load paths.
            st.session_state['button_label'] = "Reload Model"
        except Exception as e:
            st.error(f"Error reloading model: {e}")
        finally:
            free_gpu_resources()

    def delete_model(self):
        """
        Removes the loaded KBVQA model from the session state and frees GPU resources.

        The 'kbvqa' entry is only deleted when `is_model_loaded` is True; GPU
        resources are freed regardless.
        """
        free_gpu_resources()

        if self.is_model_loaded:
            try:
                del st.session_state['kbvqa']
            except KeyError:
                # Already gone: nothing to delete.
                pass
            free_gpu_resources()

    def has_state_changed(self):
        """
        Compares current session state with the previous state to identify changes.

        Returns:
            bool: True if any recorded setting differs from its current value,
            False otherwise (including when no previous state has been recorded).
        """
        previous = st.session_state['previous_state']
        return any(st.session_state[key] != value for key, value in previous.items())

    def get_model(self):
        """
        Retrieve the KBVQA model from the session state.

        Returns: KBVQA object: The loaded KBVQA model, or None if not loaded.
        """
        return st.session_state.get('kbvqa', None)

    @property
    def is_model_loaded(self):
        """
        Checks if the KBVQA model is loaded in the session state.

        Returns:
            True if a model is stored and reports `all_models_loaded`; False otherwise.
        """
        model = st.session_state.get('kbvqa')
        return model is not None and model.all_models_loaded

    def reload_detection_model(self):
        """
        Reloads only the detection model of the KBVQA model with updated settings.

        - Frees GPU resources before reloading, and again afterwards (also on failure).
        - Does nothing beyond freeing resources if no model is loaded.
        - Calls `prepare_kbvqa_model` with `only_reload_detection_model=True`.
        - Updates detection confidence level on the model object.
        - Displays a success message if the model is reloaded successfully.
        """
        try:
            free_gpu_resources()
            if self.is_model_loaded:
                # NOTE(review): the return value is unused; presumably the helper
                # updates the cached model in place — confirm in KBVQA.
                prepare_kbvqa_model(only_reload_detection_model=True)
                st.session_state['kbvqa'].detection_confidence = st.session_state.confidence_level
                self.col1.success("Model reloaded with updated settings and ready for inference.")
                st.session_state['previous_state'] = self._current_settings()
                st.session_state['button_label'] = "Reload Model"
        except Exception as e:
            st.error(f"Error reloading detection model: {e}")
        finally:
            free_gpu_resources()

    def process_new_image(self, image_key, image, kbvqa):
        """
        Processes a new uploaded image by creating an entry in the `images_data` dictionary in the application session state.

        This dictionary stores information about each processed image, including:
            - `image`: The original image data.
            - `caption`: Generated caption for the image.
            - `detected_objects_str`: String representation of detected objects.
            - `qa_history`: List of questions and answers related to the image.
            - `analysis_done`: Flag indicating if analysis is complete.

        Existing entries are left untouched.

        Args:
            image_key (str): Unique key for the image.
            image (obj): The uploaded image data.
            kbvqa (KBVQA object): The loaded KBVQA model (currently unused).
        """
        if image_key not in st.session_state['images_data']:
            st.session_state['images_data'][image_key] = {
                'image': image,
                'caption': '',
                'detected_objects_str': '',
                'qa_history': [],
                'analysis_done': False
            }

    def analyze_image(self, image, kbvqa):
        """
        Analyzes the image using the KBVQA model.

        - Creates a copy of the image to avoid modifying the original.
        - Displays a "Analyzing the image .." message.
        - Calls KBVQA methods to generate a caption and detect objects.

        Args:
            image (obj): The image data to analyze.
            kbvqa (KBVQA object): The loaded KBVQA model.

        Returns:
            tuple: (caption, detected objects string, image with bounding boxes).
        """
        img = copy.deepcopy(image)
        st.text("Analyzing the image .. ")
        caption = kbvqa.get_caption(img)
        image_with_boxes, detected_objects_str = kbvqa.detect_objects(img)
        return caption, detected_objects_str, image_with_boxes

    def add_to_qa_history(self, image_key, question, answer, prompt_length):
        """
        Adds a question-answer record to the QA history of a specific image, used as a history tracker.

        Nothing is recorded if the image has not been registered via `process_new_image`.

        Args:
            image_key (str): Unique key for the image.
            question (str): The question asked about the image.
            answer (str): The answer generated by the KBVQA model.
            prompt_length: Length of the prompt used to produce the answer, stored alongside it.
        """
        if image_key in st.session_state['images_data']:
            st.session_state['images_data'][image_key]['qa_history'].append((question, answer, prompt_length))

    def get_images_data(self):
        """
        Returns the dictionary containing processed image data from the session state.

        Returns:
            dict: The dictionary storing information about processed images.
        """
        return st.session_state['images_data']

    def update_image_data(self, image_key, caption, detected_objects_str, analysis_done):
        """
        Updates the information stored for a specific image in the `images_data` dictionary in the application session state.

        Unknown image keys are ignored.

        Args:
            image_key (str): Unique key for the image.
            caption (str): The generated caption for the image.
            detected_objects_str (str): String representation of detected objects.
            analysis_done (bool): Flag indicating if analysis of the image is complete.
        """
        if image_key in st.session_state['images_data']:
            st.session_state['images_data'][image_key].update({
                'caption': caption,
                'detected_objects_str': detected_objects_str,
                'analysis_done': analysis_done
            })

    def resize_image(self, image_input, new_width=None, new_height=None):
        """
        Resize an image. If only one dimension is provided, the other is adjusted to maintain aspect ratio.
        If both new_width and new_height are provided, the image is resized to those dimensions.

        Args:
            image_input (str | PIL.Image.Image): A file path or a PIL image to resize.
            new_width (int, optional): The target width of the image.
            new_height (int, optional): The target height of the image.

        Returns:
            PIL.Image.Image: The resized image (a new object; the input is not modified).

        Raises:
            ValueError: If image_input has an unsupported type, or neither dimension is given.
        """
        if isinstance(image_input, str):
            # Open from a file path; the context manager closes the file once the
            # resized copy has been produced.
            with Image.open(image_input) as image:
                return self._resize_to(image, new_width, new_height)
        if isinstance(image_input, Image.Image):
            # `resize` returns a new image, so no defensive copy is needed.
            return self._resize_to(image_input, new_width, new_height)
        raise ValueError("image_input must be a file path or a PIL Image object")

    @staticmethod
    def _resize_to(image, new_width, new_height):
        """Resize `image`, deriving a missing dimension from the original aspect ratio."""
        if new_width is None and new_height is None:
            raise ValueError("At least one of new_width or new_height must be provided")
        original_width, original_height = image.size
        if new_height is None:
            new_height = int(original_height * (new_width / original_width))
        elif new_width is None:
            new_width = int(original_width * (new_height / original_height))
        return image.resize((new_width, new_height))

    def display_message(self, message, message_type):
        """
        Show `message` using the Streamlit element matching `message_type`.

        Args:
            message (str): The text to display.
            message_type (str): One of "warning", "text", "success" or "write";
                any other value shows an error instead.
        """
        renderers = {
            "warning": st.warning,
            "text": st.text,
            "success": st.success,
            "write": st.write,
        }
        renderer = renderers.get(message_type)
        if renderer is None:
            st.error("Message type unknown")
        else:
            renderer(message)