File size: 13,884 Bytes
18d1852
e46d486
0a62769
18d1852
49e9e5b
ffdb10e
18d1852
11a17ef
6d4d5ac
18d1852
 
 
11a17ef
 
 
 
18d1852
 
 
 
 
12f08dc
 
4170a5f
74d450c
b7c642c
f4bcc28
a264403
 
 
 
 
09ff02d
2152f1f
2957e90
1b71503
cf8c147
 
 
bfdde42
11a17ef
 
 
 
18d1852
08655fb
 
11a17ef
 
 
 
3fdd1d7
d80fd56
bfdde42
 
cf8c147
 
 
 
 
 
 
 
 
 
 
 
 
bfdde42
 
 
 
3fdd1d7
d80fd56
f4bcc28
 
cf8c147
 
 
 
 
 
f4bcc28
18d1852
3fdd1d7
18d1852
cf8c147
 
 
 
 
11a17ef
08655fb
d6a4897
a264403
11a17ef
18d1852
3fdd1d7
18d1852
cf8c147
 
 
 
3689b26
18d1852
 
 
3fdd1d7
18d1852
d0e9fe6
cf8c147
 
 
 
 
 
 
 
 
 
753c201
 
cc825df
d9364fd
f72214b
 
a264403
12f08dc
753c201
 
f72214b
 
ffdb10e
 
 
 
 
 
 
 
 
 
 
 
f72214b
 
cf8c147
 
 
 
 
 
f72214b
 
 
87da9a2
753c201
3fdd1d7
18d1852
cf8c147
 
 
 
 
18d1852
 
3fdd1d7
18d1852
cf8c147
 
 
 
 
 
18d1852
 
3fdd1d7
cd3678b
cf8c147
 
 
 
 
 
 
 
 
 
18d1852
 
 
1fc0405
9a8f19a
11a17ef
18d1852
 
 
 
d6f8382
18d1852
cf8c147
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18d1852
 
 
 
 
 
 
 
 
3fdd1d7
18d1852
cf8c147
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18d1852
d182243
18d1852
 
 
 
3fdd1d7
18d1852
cf8c147
 
 
 
 
 
 
 
18d1852
 
 
3fdd1d7
18d1852
cf8c147
 
 
 
 
 
18d1852
 
f993077
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2cd3e03
 
 
 
 
 
 
 
f993077
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ba99375
436f3b4
3fdd1d7
18d1852
cf8c147
 
 
 
 
 
 
 
 
18d1852
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
import pandas as pd
import copy
from PIL import Image
import streamlit as st
from my_model.utilities.gen_utilities import free_gpu_resources
from my_model.KBVQA import KBVQA, prepare_kbvqa_model, force_reload_model




class StateManager:
    """
    Holds the page's three layout columns and helpers that read/write ``st.session_state``.

    All persistent application data (loaded model, processed images, previous settings)
    lives in ``st.session_state``; the instance itself only keeps the columns created in
    ``__init__``.
    """

    def __init__(self):
        """Create the three layout columns with relative widths 0.2 / 0.6 / 0.2."""
        # Create three columns with different widths
        self.col1, self.col2, self.col3 = st.columns([0.2, 0.6, 0.2])

    def initialize_state(self):
        """
        Seed default values for every session-state key this class relies on.

        Keys that already exist are never overwritten, so calling this repeatedly is harmless.
        """
        if 'images_data' not in st.session_state:
            st.session_state['images_data'] = {}
        if 'kbvqa' not in st.session_state:
            st.session_state['kbvqa'] = None
        if "button_label" not in st.session_state:
            st.session_state['button_label'] = "Load Model"
        # 'previous_state' is seeded before 'settings_changed' because the
        # settings_changed property compares the session against it.
        if "previous_state" not in st.session_state:
            st.session_state['previous_state'] = {}
        if "settings_changed" not in st.session_state:
            st.session_state['settings_changed'] = self.settings_changed
        if 'model_loaded' not in st.session_state:
            st.session_state['model_loaded'] = False
        if 'loading_in_progress' not in st.session_state:
            st.session_state['loading_in_progress'] = False

    def set_up_widgets(self):
        """
        Sets up user interface widgets for selecting models, settings, and displaying model settings conditionally.

        The widgets are keyed ('method', 'detection_model', 'confidence_level'), so their
        current values are read back from ``st.session_state`` by the other methods.
        """
        self.col1.selectbox("Choose a method:", ["Fine-Tuned Model", "In-Context Learning (n-shots)"], index=0, key='method')
        self.col1.selectbox("Choose a model for objects detection:", ["yolov5", "detic"], index=1, key='detection_model')
        # The default confidence threshold depends on the selected detector.
        default_confidence = 0.2 if st.session_state.detection_model == "yolov5" else 0.4
        self.set_slider_value(text="Select minimum detection confidence level", min_value=0.1, max_value=0.9, value=default_confidence, step=0.1, slider_key_name='confidence_level', col=self.col1)

        # Conditional display of model settings
        show_model_settings = self.col3.checkbox("Show Model Settings", False)
        if show_model_settings:
            self.display_model_settings()

    def set_slider_value(self, text, min_value, max_value, value, step, slider_key_name, col=None):
        """
        Creates a slider widget with the specified parameters, optionally placing it in a specific column.

        Args:
            text (str): Text to display next to the slider.
            min_value (float): Minimum value for the slider.
            max_value (float): Maximum value for the slider.
            value (float): Initial value for the slider.
            step (float): Step size for the slider.
            slider_key_name (str): Unique session-state key for the slider.
            col (streamlit.columns.Column, optional): Column to place the slider in. Defaults to None (displayed in main area).

        Returns:
            Whatever the underlying ``slider`` call returns (the slider's current value).
        """
        container = st if col is None else col
        return container.slider(text, min_value, max_value, value, step, key=slider_key_name)

    @property
    def settings_changed(self):
        """
        Checks if any model settings have changed compared to the previous state.

        Returns:
            bool: True if any setting has changed, False otherwise.
        """
        return self.has_state_changed()

    def display_model_settings(self):
        """
        Displays a table of current model settings in the third column.

        Uses a pandas Styler to give the table a white body and a bold gray header.
        """
        self.col3.write("##### Current Model Settings:")
        shown_keys = ("confidence_level", 'detection_model', 'method', 'kbvqa', 'previous_state', 'settings_changed')
        data = [{'Key': key, 'Value': str(value)} for key, value in st.session_state.items() if key in shown_keys]
        df = pd.DataFrame(data)
        styled_df = df.style.set_properties(**{'background-color': 'white', 'color': 'black', 'border-color': 'black'}).set_table_styles([{'selector': 'th', 'props': [('background-color', 'gray'), ('font-weight', 'bold')]}])
        self.col3.table(styled_df)

    def display_session_state(self):
        """
        Displays a table of the complete application state.
        """
        st.write("Current Model:")
        data = [{'Key': key, 'Value': str(value)} for key, value in st.session_state.items()]
        df = pd.DataFrame(data)
        st.table(df)

    def _current_settings(self):
        """Return a snapshot of the widget settings tracked for change detection."""
        return {'method': st.session_state.method, 'detection_model': st.session_state.detection_model, 'confidence_level': st.session_state.confidence_level}

    def load_model(self):
        """
        Loads the KBVQA model based on the chosen method and settings.

        - Frees GPU resources before and after loading.
        - Calls `prepare_kbvqa_model` to create the model.
        - Sets the detection confidence level on the model object.
        - Updates previous state with current settings for change detection.
        - Marks the model as loaded and updates the button label to "Reload Model".

        Any exception is reported with ``st.error`` instead of being raised.
        """
        try:
            free_gpu_resources()
            st.session_state['kbvqa'] = prepare_kbvqa_model()
            st.session_state['kbvqa'].detection_confidence = st.session_state.confidence_level
            # Update the previous state with current session state values
            st.session_state['previous_state'] = self._current_settings()
            st.session_state['model_loaded'] = True
            st.session_state['button_label'] = "Reload Model"
            free_gpu_resources()
        except Exception as e:
            st.error(f"Error loading model: {e}")

    def force_reload_kbvqa(self):
        """
        Forces a fresh load of the KBVQA model via `force_reload_model`.

        Frees GPU resources first, applies the current detection confidence, snapshots the
        current settings into ``previous_state`` and marks the model as loaded. Unlike
        `load_model`, it does not change the button label. Any exception is reported with
        ``st.error`` instead of being raised.
        """
        try:
            free_gpu_resources()
            st.session_state['kbvqa'] = force_reload_model()
            st.session_state['kbvqa'].detection_confidence = st.session_state.confidence_level
            # Update the previous state with current session state values
            st.session_state['previous_state'] = self._current_settings()
            st.session_state['model_loaded'] = True
        except Exception as e:
            st.error(f"Error loading model: {e}")

    def has_state_changed(self):
        """
        Compares current session state with the previous state to identify changes.

        Returns:
            bool: True if any tracked setting differs from (or is missing compared to) its
            previously recorded value, False otherwise (including when nothing has been
            recorded yet).
        """
        previous = st.session_state.get('previous_state', {})
        # .get() instead of [] so a tracked key absent from the session is reported as a
        # change rather than raising KeyError.
        return any(st.session_state.get(key) != value for key, value in previous.items())

    def get_model(self):
        """
        Retrieve the KBVQA model from the session state.

        Returns:
            KBVQA object: The loaded KBVQA model, or None if not loaded.
        """
        return st.session_state.get('kbvqa', None)

    def is_model_loaded(self):
        """
        Checks if the KBVQA model is loaded in the session state.

        Returns:
            bool: True if the model is loaded, False otherwise.
        """
        return 'kbvqa' in st.session_state and st.session_state['kbvqa'] is not None

    def reload_detection_model(self):
        """
        Reloads only the detection model of the KBVQA model with updated settings.

        - Frees GPU resources before and after reloading.
        - Does nothing else if no model is loaded yet.
        - Calls `prepare_kbvqa_model` with `only_reload_detection_model=True`.
        - Updates detection confidence level on the model object.
        - Displays a success message if the model is reloaded successfully.

        Any exception is reported with ``st.error`` instead of being raised.
        """
        try:
            free_gpu_resources()
            if self.is_model_loaded():
                # NOTE(review): the return value is discarded, so this presumably updates the
                # detector on the model already stored in session state -- confirm.
                prepare_kbvqa_model(only_reload_detection_model=True)
                st.session_state['kbvqa'].detection_confidence = st.session_state.confidence_level
                # NOTE(review): unlike load_model/force_reload_kbvqa, 'previous_state' is not
                # refreshed here, so settings_changed keeps reporting True afterwards -- confirm
                # this is intended.
                self.col1.success("Model reloaded with updated settings and ready for inference.")
            free_gpu_resources()
        except Exception as e:
            st.error(f"Error reloading detection model: {e}")

    def process_new_image(self, image_key, image, kbvqa):
        """
        Processes a new uploaded image by creating an entry in the `images_data` dictionary in the application session state.

        This dictionary stores information about each processed image, including:
            - `image`: The original image data.
            - `caption`: Generated caption for the image.
            - `detected_objects_str`: String representation of detected objects.
            - `qa_history`: List of questions and answers related to the image.
            - `analysis_done`: Flag indicating if analysis is complete.

        An existing entry for `image_key` is left untouched.

        Args:
            image_key (str): Unique key for the image.
            image (obj): The uploaded image data.
            kbvqa (KBVQA object): The loaded KBVQA model (currently unused here).
        """
        if image_key not in st.session_state['images_data']:
            st.session_state['images_data'][image_key] = {
                'image': image,
                'caption': '',
                'detected_objects_str': '',
                'qa_history': [],
                'analysis_done': False
            }

    def analyze_image(self, image, kbvqa):
        """
        Analyzes the image using the KBVQA model.

        - Creates a copy of the image to avoid modifying the original.
        - Displays a "Analyzing the image .." message.
        - Calls KBVQA methods to generate a caption and detect objects.
        - Returns the generated caption, detected objects string, and image with bounding boxes.

        Args:
            image (obj): The image data to analyze.
            kbvqa (KBVQA object): The loaded KBVQA model.

        Returns:
            tuple: A tuple containing the generated caption, detected objects string, and image with bounding boxes.
        """
        img = copy.deepcopy(image)
        st.text("Analyzing the image .. ")
        caption = kbvqa.get_caption(img)
        image_with_boxes, detected_objects_str = kbvqa.detect_objects(img)
        return caption, detected_objects_str, image_with_boxes

    def add_to_qa_history(self, image_key, question, answer):
        """
        Adds a question-answer pair to the QA history of a specific image, to be used as history tracker.

        Does nothing if `image_key` has no entry in `images_data`.

        Args:
            image_key (str): Unique key for the image.
            question (str): The question asked about the image.
            answer (str): The answer generated by the KBVQA model.
        """
        if image_key in st.session_state['images_data']:
            st.session_state['images_data'][image_key]['qa_history'].append((question, answer))

    def get_images_data(self):
        """
        Returns the dictionary containing processed image data from the session state.

        Returns:
            dict: The dictionary storing information about processed images.
        """
        return st.session_state['images_data']

    def resize_image(self, image_input, new_width=None, new_height=None):
        """
        Resize an image. If only one of new_width / new_height is provided, the other is
        derived to maintain the aspect ratio. If both are provided, the image is resized to
        exactly those dimensions.

        Args:
            image_input (str | PIL.Image.Image): A file path or a PIL image. The input image is never modified.
            new_width (int, optional): The target width of the image.
            new_height (int, optional): The target height of the image.

        Returns:
            PIL.Image.Image: A new, resized image.

        Raises:
            ValueError: If image_input has an unsupported type, or neither dimension is given.
        """
        if isinstance(image_input, str):
            # The context manager closes the file handle once the resized image exists;
            # resize() returns an independent image, so the result stays usable.
            with Image.open(image_input) as opened:
                return self._resize_pil_image(opened, new_width, new_height)
        if isinstance(image_input, Image.Image):
            return self._resize_pil_image(image_input, new_width, new_height)
        raise ValueError("image_input must be a file path or a PIL Image object")

    @staticmethod
    def _resize_pil_image(image, new_width, new_height):
        """Resize a PIL image, deriving a missing dimension from the aspect ratio (min 1 px)."""
        if new_width is None and new_height is None:
            raise ValueError("At least one of new_width or new_height must be provided")
        original_width, original_height = image.size
        if new_height is None:
            # Calculate new height to maintain aspect ratio; clamp so extreme ratios never yield 0.
            ratio = new_width / original_width
            new_height = max(1, int(original_height * ratio))
        elif new_width is None:
            # Calculate new width to maintain aspect ratio; clamp so extreme ratios never yield 0.
            ratio = new_height / original_height
            new_width = max(1, int(original_width * ratio))
        return image.resize((new_width, new_height))

    def update_image_data(self, image_key, caption, detected_objects_str, analysis_done):
        """
        Updates the information stored for a specific image in the `images_data` dictionary in the application session state.

        Does nothing if `image_key` has no entry in `images_data`.

        Args:
            image_key (str): Unique key for the image.
            caption (str): The generated caption for the image.
            detected_objects_str (str): String representation of detected objects.
            analysis_done (bool): Flag indicating if analysis of the image is complete.
        """
        if image_key in st.session_state['images_data']:
            st.session_state['images_data'][image_key].update({
                'caption': caption,
                'detected_objects_str': detected_objects_str,
                'analysis_done': analysis_done
            })