File size: 18,383 Bytes
18d1852
e46d486
69b926c
0a62769
eaa0bad
18d1852
49e9e5b
a812c7b
18d1852
11a17ef
6d4d5ac
18d1852
 
eaa0bad
 
 
 
 
 
 
 
 
 
 
 
 
 
18d1852
11a17ef
eaa0bad
 
 
 
11a17ef
 
 
18d1852
eaa0bad
 
 
ccdbda8
 
18d1852
 
 
 
12f08dc
 
a264403
 
96dd295
 
 
 
69b926c
1f66159
9becb2c
 
 
 
ccdbda8
eaa0bad
cf8c147
 
 
929a4bd
1a38c08
 
11a17ef
b42fa65
18d1852
08655fb
 
11a17ef
fcf48a1
11a17ef
a5c03d9
3fdd1d7
d80fd56
bfdde42
eaa0bad
cf8c147
 
 
 
 
 
 
 
 
 
 
 
 
bfdde42
4e59653
bfdde42
4e59653
3fdd1d7
eaa0bad
141a983
 
 
de78d1a
 
eaa0bad
 
 
 
de78d1a
eaa0bad
d80fd56
f4bcc28
 
cf8c147
 
 
 
 
 
f4bcc28
eaa0bad
8624b37
 
 
cf8c147
eaa0bad
 
 
 
cf8c147
eaa0bad
 
18d1852
eaa0bad
 
cf8c147
eaa0bad
cf8c147
3fdd1d7
2250430
 
eaa0bad
2250430
eaa0bad
cf8c147
 
 
 
 
 
 
 
 
 
753c201
 
cc825df
d9364fd
f72214b
2250430
a264403
12f08dc
753c201
e2de402
2b3b1de
753c201
f72214b
 
eaa0bad
 
 
 
 
 
 
 
 
 
 
ffdb10e
96dd295
ffdb10e
a812c7b
ffdb10e
 
2250430
 
ffdb10e
a812c7b
ffdb10e
96dd295
a812c7b
ffdb10e
eaa0bad
96dd295
eaa0bad
96dd295
 
 
 
9becb2c
da52f83
 
 
e2de402
da52f83
e2de402
da52f83
 
 
ffdb10e
f72214b
eaa0bad
cf8c147
 
 
 
 
 
f72214b
2250430
 
7e798e5
87db14a
f72214b
87da9a2
753c201
3fdd1d7
eaa0bad
cf8c147
 
 
 
 
18d1852
 
9becb2c
eaa0bad
cf8c147
 
 
 
 
 
0675d16
 
 
 
18d1852
3fdd1d7
eaa0bad
cf8c147
 
 
 
 
 
 
 
 
 
18d1852
 
9becb2c
1fc0405
9a8f19a
11a17ef
2250430
9becb2c
0ed508e
18d1852
 
 
 
d6f8382
eaa0bad
cf8c147
 
 
 
 
 
 
 
 
 
 
 
 
1a4044e
cf8c147
 
18d1852
 
 
 
 
 
 
 
 
3fdd1d7
eaa0bad
cf8c147
 
 
 
 
 
 
 
 
 
1a4044e
cf8c147
 
 
 
e2de402
 
 
18d1852
1f66159
c6c0b0e
87db14a
18d1852
 
3fdd1d7
eaa0bad
cf8c147
 
 
 
 
 
 
 
18d1852
6a15fc4
18d1852
3fdd1d7
18d1852
cf8c147
 
 
 
 
 
18d1852
9becb2c
 
 
 
 
 
 
 
 
 
 
 
726be01
9becb2c
 
 
 
 
18d1852
afa0e9e
f993077
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2cd3e03
 
 
 
 
 
 
 
f993077
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ba99375
436f3b4
afa0e9e
9becb2c
 
 
 
 
 
 
 
 
eaa0bad
 
8d3df8c
 
eaa0bad
 
 
 
 
 
9da8a28
eaa0bad
 
 
 
26f093e
eaa0bad
 
 
 
26f093e
eaa0bad
 
26f093e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
import pandas as pd
import copy
import time
from PIL import Image
from typing import Tuple, Dict
import streamlit as st
from my_model.utilities.gen_utilities import free_gpu_resources
from my_model.KBVQA import KBVQA, prepare_kbvqa_model




class StateManager:
    """
    Helper that manages the Streamlit session state and the KBVQA model lifecycle.

    It creates the three-column page layout, builds the model-selection widgets, tracks
    whether the user changed settings since the last (re)load, loads/reloads/deletes the
    KBVQA model stored in ``st.session_state['kbvqa']``, and keeps per-image analysis
    data in ``st.session_state['images_data']``.

    Method overview:
        initialize_state: Initializes default values for session state.
        set_up_widgets: Creates UI elements for model selection and settings.
        set_slider_value: Generates a slider widget for numerical input.
        is_widget_disabled: Returns True if UI elements should be disabled.
        disable_widgets: Disables interactive UI elements during processing.
        settings_changed: Checks if any model settings have changed.
        confidance_change: Determines if the confidence level setting has changed.
        display_model_settings: Shows current model settings in the UI.
        display_session_state: Displays the current state of the application.
        update_prev_state: Updates the record of the previous application state.
        force_reload_model: Reloads the model, clearing and resetting necessary states.
    """

    # Session-state keys shown in the "Current Model Settings" table.
    _MODEL_SETTINGS_KEYS = (
        "confidence_level", 'detection_model', 'method', 'kbvqa', 'previous_state',
        'settings_changed', 'loading_in_progress', 'model_loaded',
        'time_taken_to_load_model', 'images_data',
    )

    def __init__(self):
        """
        Initializes the StateManager instance, setting up the Streamlit columns for the user interface.

        Three columns are created with relative widths 0.2 / 0.6 / 0.2. ``col1`` receives the
        model-selection widgets and ``col3`` the settings panel; ``col2`` is not used by this class.
        """
        self.col1, self.col2, self.col3 = st.columns([0.2, 0.6, 0.2])

    def initialize_state(self):
        """
        Initializes the Streamlit session state with default values for various keys.

        Keys that already exist are left untouched, so calling this on every rerun is safe.
        """
        # Fresh objects are built on every call, so no mutable default is shared.
        static_defaults = {
            'previous_state': {'method': None, 'detection_model': None, 'confidence_level': None},
            'images_data': {},
            'kbvqa': None,
            'button_label': "Load Model",
            'loading_in_progress': False,
            'load_button_clicked': False,
            'force_reload_button_clicked': False,
            'time_taken_to_load_model': None,
        }
        for key, default in static_defaults.items():
            if key not in st.session_state:
                st.session_state[key] = default

        # These two are derived from 'previous_state' / 'kbvqa', so they are computed
        # only after the static defaults above are guaranteed to exist.
        if "settings_changed" not in st.session_state:
            st.session_state['settings_changed'] = self.settings_changed
        if 'model_loaded' not in st.session_state:
            st.session_state['model_loaded'] = self.is_model_loaded

    def set_up_widgets(self) -> None:
        """
        Sets up user interface widgets for selecting models, settings, and displaying model settings conditionally.

        The widgets write their values to session state under the keys 'method',
        'detection_model' and 'confidence_level'. All widgets are disabled while a model is loading.
        """
        self.col1.selectbox("Choose a model:", ["13b-Fine-Tuned Model", "7b-Fine-Tuned Model", "Vision-Language Embeddings Alignment"], index=1, key='method', disabled=self.is_widget_disabled)
        self.col1.selectbox("Choose a model for objects detection:", ["yolov5", "detic"], index=0, key='detection_model', disabled=self.is_widget_disabled)
        # Each detector gets its own default minimum confidence.
        default_confidence = 0.2 if st.session_state.detection_model == "yolov5" else 0.4
        self.set_slider_value(text="Select minimum detection confidence level", min_value=0.1, max_value=0.9, value=default_confidence, step=0.05, slider_key_name='confidence_level', col=self.col1)

        # Conditional display of model settings.
        show_model_settings = self.col3.checkbox("Show Model Settings", True, disabled=self.is_widget_disabled)
        if show_model_settings:
            # display_model_settings is a property: reading it renders the settings table.
            self.display_model_settings

    def set_slider_value(self, text: str, min_value: float, max_value: float, value: float, step: float, slider_key_name: str, col=None) -> float:
        """
        Creates a slider widget with the specified parameters, optionally placing it in a specific column.

        Args:
            text (str): Text to display next to the slider.
            min_value (float): Minimum value for the slider.
            max_value (float): Maximum value for the slider.
            value (float): Initial value for the slider.
            step (float): Step size for the slider.
            slider_key_name (str): Unique key for the slider.
            col (streamlit.columns.Column, optional): Column to place the slider in. Defaults to None (displayed in main area).

        Returns:
            float: The value currently selected on the slider.
        """
        container = st if col is None else col
        return container.slider(text, min_value, max_value, value, step, key=slider_key_name, disabled=self.is_widget_disabled)

    @property
    def is_widget_disabled(self):
        """
        Whether interactive widgets should be disabled.

        Returns:
            bool: The current value of ``st.session_state['loading_in_progress']``.
        """
        return st.session_state['loading_in_progress']

    def disable_widgets(self):
        """
        Disables widgets by setting the 'loading_in_progress' state to True.
        """
        st.session_state['loading_in_progress'] = True

    @property
    def settings_changed(self):
        """
        Checks if any model settings have changed compared to the previous state.

        The confidence level is excluded here; it is tracked by ``confidance_change``.

        Returns:
            bool: True if any setting has changed, False otherwise.
        """
        return self.has_state_changed()

    @property
    def confidance_change(self):
        """
        Checks if the confidence level setting has changed compared to the previous state.

        Returns:
            bool: True if the confidence level has changed, False otherwise.
        """
        return st.session_state["confidence_level"] != st.session_state["previous_state"]["confidence_level"]

    def update_prev_state(self):
        """
        Updates the 'previous_state' in the session state with the current state values.

        Every key tracked in 'previous_state' ('method', 'detection_model',
        'confidence_level') is copied from the corresponding session-state entry.
        """
        previous = st.session_state['previous_state']
        for key in previous:
            previous[key] = st.session_state[key]

    def load_model(self) -> None:
        """
        Loads the KBVQA model based on the chosen method and settings.

        - Frees GPU resources before loading.
        - Calls `prepare_kbvqa_model` to create the model. It receives no arguments here;
          presumably it reads the selected method from session state (TODO confirm).
        - Sets the detection confidence level on the model object.
        - Updates previous state with current settings for change detection.
        - Updates the button label to "Reload Model".

        Errors are reported in the UI via ``st.error`` rather than raised.
        """
        try:
            free_gpu_resources()
            st.session_state['kbvqa'] = prepare_kbvqa_model()
            st.session_state['kbvqa'].detection_confidence = st.session_state.confidence_level
            # Record the settings this model was built with, for change detection.
            self.update_prev_state()
            st.session_state['model_loaded'] = True
            st.session_state['button_label'] = "Reload Model"
            free_gpu_resources()
            free_gpu_resources()
        except Exception as e:
            st.error(f"Error loading model: {e}")
            # Release whatever a partial load may have allocated (matches force_reload_model).
            free_gpu_resources()

    def force_reload_model(self) -> None:
        """
        Forces a reload of all models, freeing up GPU resources.

        - Deletes the current KBVQA model from the session state (via `delete_model`).
        - Calls `prepare_kbvqa_model` with `force_reload=True` to reload the model.
        - Updates the detection confidence level on the model object.
        - Updates previous state with current settings and marks the model as loaded.

        Errors are reported in the UI via ``st.error`` rather than raised.
        """
        try:
            self.delete_model()
            free_gpu_resources()
            st.session_state['kbvqa'] = prepare_kbvqa_model(force_reload=True)
            st.session_state['kbvqa'].detection_confidence = st.session_state.confidence_level
            # Record the settings this model was built with, for change detection.
            self.update_prev_state()
            st.session_state['model_loaded'] = True
            free_gpu_resources()
        except Exception as e:
            st.error(f"Error reloading model: {e}")
            free_gpu_resources()

    def delete_model(self) -> None:
        """
        Removes the KBVQA model from session state (only when `is_model_loaded` is True) and calls `free_gpu_resources`.
        """
        free_gpu_resources()

        if self.is_model_loaded:
            try:
                del st.session_state['kbvqa']
            except KeyError:
                # Key already absent: nothing left to delete.
                pass
            # Called twice, as elsewhere in this class; presumably a second pass reclaims
            # memory released by the first (TODO confirm against free_gpu_resources).
            free_gpu_resources()
            free_gpu_resources()

    def has_state_changed(self) -> bool:
        """
        Compares current session state with the previous state to identify changes.

        'confidence_level' is skipped, because it is tracked separately by `confidance_change`.
        Keys that are not yet present in session state are also skipped.

        Returns:
            bool: True if any change is found, False otherwise.
        """
        previous = st.session_state['previous_state']
        return any(
            key in st.session_state and st.session_state[key] != previous[key]
            for key in previous
            if key != 'confidence_level'
        )

    def get_model(self) -> KBVQA:
        """
        Retrieve the KBVQA model from the session state.

        Returns:
            KBVQA object: The loaded KBVQA model, or None if not loaded.
        """
        return st.session_state.get('kbvqa', None)

    @property
    def is_model_loaded(self) -> bool:
        """
        Checks if the KBVQA model is loaded in the session state.

        A model counts as loaded only if all of the following hold:
        - it exists;
        - it reports ``all_models_loaded``;
        - it was built for the method that is currently selected.

        Returns:
            bool: True if the model is loaded, False otherwise.
        """
        kbvqa = st.session_state.get('kbvqa')
        if kbvqa is None or not kbvqa.all_models_loaded:
            return False
        loaded_method = st.session_state['previous_state']['method']
        return loaded_method is not None and st.session_state['method'] == loaded_method

    def reload_detection_model(self) -> None:
        """
        Reloads only the detection model of the KBVQA model with updated settings.

        - Frees GPU resources before reloading.
        - Does nothing further unless the model is already loaded.
        - Calls `prepare_kbvqa_model` with `only_reload_detection_model=True`. Its return
          value is not used; presumably it updates ``st.session_state['kbvqa']`` in place (TODO confirm).
        - Updates detection confidence level on the model object.
        - Displays a success message and records the new settings as the previous state.

        Errors are reported in the UI via ``st.error`` rather than raised.
        """
        try:
            free_gpu_resources()
            if self.is_model_loaded:
                prepare_kbvqa_model(only_reload_detection_model=True)
                st.session_state['kbvqa'].detection_confidence = st.session_state.confidence_level
                self.col1.success("Model reloaded with updated settings and ready for inference.")
                # Must actually be called: otherwise the new detection settings keep
                # being reported as "changed" on every rerun.
                self.update_prev_state()
                st.session_state['button_label'] = "Reload Model"

            free_gpu_resources()
        except Exception as e:
            st.error(f"Error reloading detection model: {e}")

    def process_new_image(self, image_key: str, image) -> None:
        """
        Processes a new uploaded image by creating an entry in the `images_data` dictionary in the application session state.

        Each entry stores:
            - `image`: The original image data.
            - `caption`: Generated caption for the image.
            - `detected_objects_str`: String representation of detected objects.
            - `qa_history`: List of questions and answers related to the image.
            - `analysis_done`: Flag indicating if analysis is complete.

        Existing entries are never overwritten.

        Args:
            image_key (str): Unique key for the image.
            image (obj): The uploaded image data.
        """
        images_data = st.session_state['images_data']
        if image_key not in images_data:
            images_data[image_key] = {
                'image': image,
                'caption': '',
                'detected_objects_str': '',
                'qa_history': [],
                'analysis_done': False
            }

    def analyze_image(self, image) -> Tuple[str, str, object]:
        """
        Analyzes the image using the KBVQA model.

        - Works on a deep copy of the image so the original is not modified.
        - Calls KBVQA methods to generate a caption and detect objects.
        - Returns the generated caption, detected objects string, and image with bounding boxes.

        Args:
            image (obj): The image data to analyze.

        Returns:
            tuple: A tuple containing the generated caption, detected objects string, and image with bounding boxes.
        """
        free_gpu_resources()
        free_gpu_resources()
        img = copy.deepcopy(image)
        model = st.session_state['kbvqa']
        caption = model.get_caption(img)
        image_with_boxes, detected_objects_str = model.detect_objects(img)
        free_gpu_resources()
        return caption, detected_objects_str, image_with_boxes

    def add_to_qa_history(self, image_key: str, question: str, answer: str, prompt_length: int) -> None:
        """
        Adds a question-answer pair to the QA history of a specific image, to be used as a history tracker.

        Unknown image keys are silently ignored.

        Args:
            image_key (str): Unique key for the image.
            question (str): The question asked about the image.
            answer (str): The answer generated by the KBVQA model.
            prompt_length (int): Length of the prompt used to produce the answer; it is stored as
                the third element of the history tuple.
        """
        if image_key in st.session_state['images_data']:
            st.session_state['images_data'][image_key]['qa_history'].append((question, answer, prompt_length))

    def get_images_data(self):
        """
        Returns the dictionary containing processed image data from the session state.

        Returns:
            dict: The dictionary storing information about processed images.
        """
        return st.session_state['images_data']

    def update_image_data(self, image_key, caption, detected_objects_str, analysis_done):
        """
        Updates the information stored for a specific image in the `images_data` dictionary in the application session state.

        Unknown image keys are silently ignored.

        Args:
            image_key (str): Unique key for the image.
            caption (str): The generated caption for the image.
            detected_objects_str (str): String representation of detected objects.
            analysis_done (bool): Flag indicating if analysis of the image is complete.
        """
        if image_key in st.session_state['images_data']:
            st.session_state['images_data'][image_key].update({
                'caption': caption,
                'detected_objects_str': detected_objects_str,
                'analysis_done': analysis_done
            })

    def resize_image(self, image_input, new_width=None, new_height=None):
        """
        Resize an image. If only one of new_width / new_height is provided, the other is derived so the aspect ratio is kept.
        If both are provided, the image is resized to exactly those dimensions.

        Args:
            image_input (str | PIL.Image.Image): File path of the image, or the image itself.
            new_width (int, optional): The target width of the image.
            new_height (int, optional): The target height of the image.

        Returns:
            PIL.Image.Image: A new, resized image. The input image is not modified.

        Raises:
            ValueError: If image_input is neither a path nor a PIL image, or if neither
                new_width nor new_height is provided.
        """
        if isinstance(image_input, str):
            # Open from disk; the context manager closes the file once the resized copy exists.
            with Image.open(image_input) as image:
                return self._resize_pil_image(image, new_width, new_height)
        if isinstance(image_input, Image.Image):
            # Image.resize always returns a new image, so the caller's object stays untouched.
            return self._resize_pil_image(image_input, new_width, new_height)
        raise ValueError("image_input must be a file path or a PIL Image object")

    @staticmethod
    def _resize_pil_image(image, new_width, new_height):
        """Return a resized copy of a PIL image, deriving a missing dimension from the aspect ratio."""
        if new_width is None and new_height is None:
            raise ValueError("At least one of new_width or new_height must be provided")

        original_width, original_height = image.size
        if new_height is None:
            ratio = new_width / original_width
            # Clamp to 1: a truncated 0 would make PIL reject the size.
            new_height = max(1, int(original_height * ratio))
        elif new_width is None:
            ratio = new_height / original_height
            new_width = max(1, int(original_width * ratio))

        return image.resize((new_width, new_height))

    def display_message(self, message, message_type):
        """
        Render ``message`` with the Streamlit element matching ``message_type``.

        Args:
            message: The content to display.
            message_type (str): One of "warning", "text", "success" or "write". Any other
                value displays an "unknown type" error instead of the message.
        """
        renderers = {
            "warning": st.warning,
            "text": st.text,
            "success": st.success,
            "write": st.write,
        }
        renderer = renderers.get(message_type)
        if renderer is None:
            st.error("Message type unknown")
        else:
            renderer(message)

    @property
    def display_model_settings(self):
        """
        Displays a table of current model settings in the third column.

        This is a property, so simply reading ``self.display_model_settings`` renders the table.
        """
        self.col3.write("##### Current Model Settings:")
        data = [{'Setting': key, 'Value': str(value)} for key, value in st.session_state.items() if key in self._MODEL_SETTINGS_KEYS]
        df = pd.DataFrame(data).reset_index(drop=True)
        return self.col3.write(df)

    def display_session_state(self, col):
        """
        Displays a table of the complete application state (every session-state key and its value).

        Args:
            col: The Streamlit container (e.g. a column) to render into.
        """
        col.write("Current Model:")
        data = [{'Key': key, 'Value': str(value)} for key, value in st.session_state.items()]
        df = pd.DataFrame(data).reset_index(drop=True)
        col.write(df)