Spaces:
Running
Running
File size: 18,032 Bytes
18d1852 e46d486 69b926c 0a62769 eaa0bad 18d1852 49e9e5b a812c7b 18d1852 11a17ef 6d4d5ac 18d1852 eaa0bad 18d1852 11a17ef eaa0bad 11a17ef 18d1852 eaa0bad ccdbda8 18d1852 12f08dc a264403 96dd295 69b926c 1f66159 9becb2c ccdbda8 eaa0bad cf8c147 929a4bd bbdd166 b0a5499 11a17ef b42fa65 18d1852 08655fb 11a17ef fcf48a1 11a17ef a5c03d9 3fdd1d7 d80fd56 bfdde42 eaa0bad cf8c147 bfdde42 4e59653 bfdde42 4e59653 3fdd1d7 eaa0bad 141a983 de78d1a eaa0bad de78d1a eaa0bad d80fd56 f4bcc28 cf8c147 f4bcc28 eaa0bad 8624b37 cf8c147 eaa0bad cf8c147 eaa0bad 18d1852 eaa0bad cf8c147 eaa0bad cf8c147 3fdd1d7 2250430 eaa0bad 2250430 eaa0bad cf8c147 753c201 cc825df d9364fd f72214b 2250430 a264403 12f08dc 753c201 2b3b1de 753c201 f72214b eaa0bad ffdb10e 96dd295 ffdb10e a812c7b ffdb10e 2250430 ffdb10e a812c7b ffdb10e 96dd295 a812c7b ffdb10e eaa0bad 96dd295 eaa0bad 96dd295 9becb2c da52f83 ffdb10e f72214b eaa0bad cf8c147 f72214b 2250430 7e798e5 87db14a f72214b 87da9a2 753c201 3fdd1d7 eaa0bad cf8c147 18d1852 9becb2c eaa0bad cf8c147 c70ac46 18d1852 3fdd1d7 eaa0bad cf8c147 18d1852 9becb2c 1fc0405 9a8f19a 11a17ef 2250430 9becb2c 0ed508e 18d1852 d6f8382 eaa0bad cf8c147 1a4044e cf8c147 18d1852 3fdd1d7 eaa0bad cf8c147 1a4044e cf8c147 18d1852 1f66159 c6c0b0e 87db14a 18d1852 3fdd1d7 eaa0bad cf8c147 18d1852 6a15fc4 18d1852 3fdd1d7 18d1852 cf8c147 18d1852 9becb2c 726be01 9becb2c 18d1852 afa0e9e f993077 2cd3e03 f993077 ba99375 436f3b4 afa0e9e 9becb2c eaa0bad 8d3df8c eaa0bad 9da8a28 eaa0bad 26f093e eaa0bad 26f093e eaa0bad 26f093e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 |
import pandas as pd
import copy
import time
from PIL import Image
from typing import Tuple, Dict
import streamlit as st
from my_model.utilities.gen_utilities import free_gpu_resources
from my_model.KBVQA import KBVQA, prepare_kbvqa_model
class StateManager:
    """
    Manages the Streamlit session state and UI widgets of the KBVQA demo app.

    Method overview:
    - initialize_state: seeds default values in ``st.session_state``.
    - set_up_widgets: creates the model-selection / settings widgets.
    - set_slider_value: creates a slider widget, optionally inside a column.
    - is_widget_disabled: True while a model is being loaded.
    - disable_widgets: flags loading as in progress so widgets get disabled.
    - settings_changed: whether method / detection model differ from the last loaded state.
    - confidance_change (alias confidence_change): whether the confidence level differs
      from the last loaded state.
    - display_model_settings: renders a table of the current model settings.
    - display_session_state: renders the complete session state.
    - update_prev_state: records the current settings as the last loaded state.
    - force_reload_model: deletes and fully reloads the model.
    """

    def __init__(self):
        """
        Creates the three Streamlit layout columns (relative widths 0.2 / 0.6 / 0.2)
        used by the rest of the UI.
        """
        self.col1, self.col2, self.col3 = st.columns([0.2, 0.6, 0.2])

    def initialize_state(self) -> None:
        """
        Initializes the Streamlit session state with default values for keys that are not set yet.

        Existing values are never overwritten, so calling this on every rerun is safe.
        """
        static_defaults = {
            'previous_state': {'method': None, 'detection_model': None, 'confidence_level': None},
            'images_data': {},
            'kbvqa': None,
            'button_label': "Load Model",
            'loading_in_progress': False,
            'load_button_clicked': False,
            'force_reload_button_clicked': False,
            'time_taken_to_load_model': None,
        }
        for key, default in static_defaults.items():
            if key not in st.session_state:
                st.session_state[key] = default

        # These two are derived from the keys seeded above, so they are computed last
        # and only when missing.
        if 'settings_changed' not in st.session_state:
            st.session_state['settings_changed'] = self.settings_changed
        if 'model_loaded' not in st.session_state:
            st.session_state['model_loaded'] = self.is_model_loaded

    def set_up_widgets(self) -> None:
        """
        Sets up the widgets for choosing the model, the object-detection model and the
        detection confidence level, and optionally shows the current model settings.

        All widgets are disabled while a model is loading (see `is_widget_disabled`).
        """
        self.col1.selectbox(
            "Choose a model:",
            ["13b-Fine-Tuned Model", "7b-Fine-Tuned Model", "Vision-Language Embeddings Alignment"],
            index=0, key='method', disabled=self.is_widget_disabled)
        detection_model = self.col1.selectbox(
            "Choose a model for objects detection:", ["yolov5", "detic"],
            index=1, key='detection_model', disabled=self.is_widget_disabled)

        # Default threshold depends on the detector: 0.2 for yolov5, 0.4 otherwise (detic).
        default_confidence = 0.2 if detection_model == "yolov5" else 0.4
        self.set_slider_value(text="Select minimum detection confidence level", min_value=0.1, max_value=0.9,
                              value=default_confidence, step=0.05, slider_key_name='confidence_level',
                              col=self.col1)

        show_model_settings = self.col3.checkbox("Show Model Settings", True, disabled=self.is_widget_disabled)
        if show_model_settings:
            # `display_model_settings` is a property: accessing it renders the settings table.
            self.display_model_settings

    def set_slider_value(self, text: str, min_value: float, max_value: float, value: float, step: float,
                         slider_key_name: str, col=None) -> float:
        """
        Creates a slider widget with the specified parameters, optionally placing it in a specific column.

        Args:
            text (str): Text to display next to the slider.
            min_value (float): Minimum value for the slider.
            max_value (float): Maximum value for the slider.
            value (float): Initial value for the slider.
            step (float): Step size for the slider.
            slider_key_name (str): Unique session-state key for the slider.
            col (optional): Streamlit column to place the slider in. Defaults to None (main area).

        Returns:
            float: The value currently selected on the slider.
        """
        container = st if col is None else col
        return container.slider(text, min_value, max_value, value, step, key=slider_key_name,
                                disabled=self.is_widget_disabled)

    @property
    def is_widget_disabled(self) -> bool:
        """
        bool: True while a model is loading (`loading_in_progress`), meaning widgets should be disabled.
        """
        return st.session_state['loading_in_progress']

    def disable_widgets(self) -> None:
        """
        Disables widgets by setting the 'loading_in_progress' state to True.
        """
        st.session_state['loading_in_progress'] = True

    @property
    def settings_changed(self) -> bool:
        """
        Checks whether the model settings differ from the previously loaded state.

        The confidence level is deliberately excluded; it is tracked by `confidance_change`.

        Returns:
            bool: True if 'method' or 'detection_model' changed, False otherwise.
        """
        return self.has_state_changed()

    @property
    def confidance_change(self) -> bool:
        """
        Checks if the confidence level setting has changed compared to the previous state.

        Returns:
            bool: True if the confidence level has changed, False otherwise.
        """
        return st.session_state["confidence_level"] != st.session_state["previous_state"]["confidence_level"]

    # Correctly spelled alias; the original (misspelled) name is kept for existing callers.
    confidence_change = confidance_change

    def update_prev_state(self) -> None:
        """
        Copies the current value of every tracked setting into 'previous_state'.

        Raises:
            KeyError: If a tracked setting is not present in the session state.
        """
        previous_state = st.session_state['previous_state']
        for key in previous_state:
            previous_state[key] = st.session_state[key]

    def load_model(self) -> None:
        """
        Loads the KBVQA model.

        - Frees GPU resources before loading and always afterwards (success or failure).
        - Calls `prepare_kbvqa_model` to create the model and stores it under 'kbvqa'.
        - Sets the detection confidence level on the model object.
        - Updates the previous state with the current settings for change detection.
        - Marks the model as loaded and changes the button label to "Reload Model".
        Any exception is reported in the UI via `st.error` instead of being raised.
        """
        try:
            free_gpu_resources()
            st.session_state['kbvqa'] = prepare_kbvqa_model()
            st.session_state['kbvqa'].detection_confidence = st.session_state.confidence_level
            self.update_prev_state()
            st.session_state['model_loaded'] = True
            st.session_state['button_label'] = "Reload Model"
        except Exception as e:
            st.error(f"Error loading model: {e}")
        finally:
            free_gpu_resources()

    def force_reload_model(self) -> None:
        """
        Forces a full reload of the KBVQA model.

        - Deletes the current model from the session state (`delete_model`).
        - Calls `prepare_kbvqa_model` with `force_reload=True` to rebuild the model.
        - Updates the detection confidence level on the model object.
        - Updates the previous state and marks the model as loaded.
        GPU resources are freed before and after, including on failure. Errors are reported
        in the UI via `st.error` instead of being raised.
        """
        try:
            self.delete_model()
            free_gpu_resources()
            st.session_state['kbvqa'] = prepare_kbvqa_model(force_reload=True)
            st.session_state['kbvqa'].detection_confidence = st.session_state.confidence_level
            self.update_prev_state()
            st.session_state['model_loaded'] = True
        except Exception as e:
            st.error(f"Error reloading model: {e}")
        finally:
            free_gpu_resources()

    def delete_model(self) -> None:
        """
        Removes the loaded KBVQA model from the session state and frees GPU resources.

        Only acts when `is_model_loaded` is True. After deletion 'model_loaded' is set to False,
        so the flag never reports a model that no longer exists.
        """
        free_gpu_resources()
        if not self.is_model_loaded:
            return
        try:
            del st.session_state['kbvqa']
        except KeyError:
            # Already removed by another code path; nothing left to delete.
            pass
        finally:
            free_gpu_resources()
        st.session_state['model_loaded'] = False

    def has_state_changed(self) -> bool:
        """
        Compares the current session state with 'previous_state' to detect setting changes.

        'confidence_level' is skipped (it is tracked separately by `confidance_change`), and keys
        missing from the session state (e.g. widgets not created yet) are not counted as changes.

        Returns:
            bool: True if any tracked setting differs from its previous value, False otherwise.
        """
        previous_state = st.session_state['previous_state']
        return any(
            key in st.session_state and st.session_state[key] != previous_value
            for key, previous_value in previous_state.items()
            if key != 'confidence_level'
        )

    def get_model(self) -> KBVQA:
        """
        Retrieves the KBVQA model from the session state.

        Returns:
            KBVQA or None: The stored 'kbvqa' object, or None if the key is missing or unset.
            Note that a returned object is not guaranteed to be fully loaded (see `is_model_loaded`).
        """
        return st.session_state.get('kbvqa', None)

    @property
    def is_model_loaded(self) -> bool:
        """
        Checks if the KBVQA model is present in the session state and reports all models loaded.

        Returns:
            bool: True if 'kbvqa' exists, is not None and its `all_models_loaded` is truthy.
        """
        return ('kbvqa' in st.session_state
                and st.session_state['kbvqa'] is not None
                and st.session_state.kbvqa.all_models_loaded)

    def reload_detection_model(self) -> None:
        """
        Reloads only the detection model of an already loaded KBVQA model.

        - Frees GPU resources before and after (also on failure).
        - Does nothing else if no model is loaded.
        - Calls `prepare_kbvqa_model(only_reload_detection_model=True)`.
        - Updates the detection confidence level on the model object.
        - Records the current settings as the previous state and sets the button label.
        - Shows a success message in the first column.
        Errors are reported in the UI via `st.error` instead of being raised.
        """
        try:
            free_gpu_resources()
            if self.is_model_loaded:
                # NOTE(review): the return value is discarded; presumably prepare_kbvqa_model
                # updates st.session_state['kbvqa'] in place — confirm.
                prepare_kbvqa_model(only_reload_detection_model=True)
                st.session_state['kbvqa'].detection_confidence = st.session_state.confidence_level
                self.update_prev_state()
                st.session_state['button_label'] = "Reload Model"
                self.col1.success("Model reloaded with updated settings and ready for inference.")
        except Exception as e:
            st.error(f"Error reloading detection model: {e}")
        finally:
            free_gpu_resources()

    def process_new_image(self, image_key: str, image) -> None:
        """
        Registers a newly uploaded image in the `images_data` dictionary of the session state.

        An entry is created only if `image_key` is not already present, with:
        - `image`: The original image data.
        - `caption`: Generated caption for the image (initially '').
        - `detected_objects_str`: String representation of detected objects (initially '').
        - `qa_history`: List of (question, answer, prompt_length) tuples (initially empty).
        - `analysis_done`: Flag indicating if analysis is complete (initially False).

        Args:
            image_key (str): Unique key for the image.
            image (obj): The uploaded image data.
        """
        if image_key not in st.session_state['images_data']:
            st.session_state['images_data'][image_key] = {
                'image': image,
                'caption': '',
                'detected_objects_str': '',
                'qa_history': [],
                'analysis_done': False,
            }

    def analyze_image(self, image) -> Tuple[str, str, object]:
        """
        Analyzes the image using the loaded KBVQA model.

        A deep copy of the image is passed to the model so the caller's object is not touched.
        GPU resources are freed after detection.

        Args:
            image (obj): The image data to analyze.

        Returns:
            tuple: (caption, detected_objects_str, image_with_boxes) as produced by the model's
            `get_caption` and `detect_objects` methods.
        """
        img = copy.deepcopy(image)
        caption = st.session_state['kbvqa'].get_caption(img)
        image_with_boxes, detected_objects_str = st.session_state['kbvqa'].detect_objects(img)
        free_gpu_resources()
        return caption, detected_objects_str, image_with_boxes

    def add_to_qa_history(self, image_key: str, question: str, answer: str, prompt_length: int) -> None:
        """
        Appends a (question, answer, prompt_length) entry to the QA history of a specific image.

        Unknown image keys are ignored.

        Args:
            image_key (str): Unique key for the image.
            question (str): The question asked about the image.
            answer (str): The answer generated by the KBVQA model.
            prompt_length (int): Length of the prompt used to generate the answer.
        """
        if image_key in st.session_state['images_data']:
            st.session_state['images_data'][image_key]['qa_history'].append((question, answer, prompt_length))

    def get_images_data(self):
        """
        Returns the dictionary containing processed image data from the session state.

        Returns:
            dict: The dictionary storing information about processed images.
        """
        return st.session_state['images_data']

    def update_image_data(self, image_key, caption, detected_objects_str, analysis_done):
        """
        Updates the stored analysis results for a specific image in `images_data`.

        Unknown image keys are ignored.

        Args:
            image_key (str): Unique key for the image.
            caption (str): The generated caption for the image.
            detected_objects_str (str): String representation of detected objects.
            analysis_done (bool): Flag indicating if analysis of the image is complete.
        """
        if image_key in st.session_state['images_data']:
            st.session_state['images_data'][image_key].update({
                'caption': caption,
                'detected_objects_str': detected_objects_str,
                'analysis_done': analysis_done,
            })

    def resize_image(self, image_input, new_width=None, new_height=None):
        """
        Resizes an image given as a file path or a PIL image.

        If only one of `new_width` / `new_height` is provided, the other is derived from the
        original aspect ratio (truncated to an int, but never below 1 pixel). If both are
        provided, the image is resized to exactly those dimensions. The input image is not modified.

        Args:
            image_input (str or PIL.Image.Image): Path to an image file, or a PIL image.
            new_width (int, optional): The target width of the image.
            new_height (int, optional): The target height of the image.

        Returns:
            PIL.Image.Image: The resized image.

        Raises:
            ValueError: If `image_input` has an unsupported type or no target size is given.
        """
        if not isinstance(image_input, (str, Image.Image)):
            raise ValueError("image_input must be a file path or a PIL Image object")
        if new_width is None and new_height is None:
            raise ValueError("At least one of new_width or new_height must be provided")

        if isinstance(image_input, str):
            # The context manager closes the file handle; `resize` returns an independent image.
            with Image.open(image_input) as opened_image:
                return self._resize_pil_image(opened_image, new_width, new_height)
        return self._resize_pil_image(image_input, new_width, new_height)

    @staticmethod
    def _resize_pil_image(image, new_width, new_height):
        """Resizes a PIL image, deriving a missing dimension from the original aspect ratio."""
        original_width, original_height = image.size
        if new_height is None:
            ratio = new_width / original_width
            new_height = max(1, int(original_height * ratio))
        elif new_width is None:
            ratio = new_height / original_height
            new_width = max(1, int(original_width * ratio))
        # `resize` always returns a new image, so the input is left untouched.
        return image.resize((new_width, new_height))

    def display_message(self, message, message_type):
        """
        Displays `message` with the Streamlit element matching `message_type`.

        Args:
            message: The content to display.
            message_type (str): One of "warning", "text", "success" or "write"; any other
                value shows an error instead.
        """
        renderers = {
            "warning": st.warning,
            "text": st.text,
            "success": st.success,
            "write": st.write,
        }
        renderer = renderers.get(message_type)
        if renderer is None:
            st.error("Message type unknown")
        else:
            renderer(message)

    @property
    def display_model_settings(self):
        """
        Displays a table of the current model settings in the third column.

        Kept as a property for backward compatibility: accessing it renders the table.
        """
        self.col3.write("##### Current Model Settings:")
        settings_keys = {"confidence_level", 'detection_model', 'method', 'kbvqa', 'previous_state',
                         'settings_changed', 'loading_in_progress', 'model_loaded',
                         'time_taken_to_load_model', 'images_data'}
        data = [{'Setting': key, 'Value': str(value)}
                for key, value in st.session_state.items() if key in settings_keys]
        df = pd.DataFrame(data).reset_index(drop=True)
        return self.col3.write(df)

    def display_session_state(self, col):
        """
        Displays a table of the complete application session state in the given column.

        Args:
            col: The Streamlit column (or container) to render into.
        """
        col.write("Current Model:")
        data = [{'Key': key, 'Value': str(value)} for key, value in st.session_state.items()]
        df = pd.DataFrame(data).reset_index(drop=True)
        col.write(df)
|