# KB-VQA / my_model/state_manager.py
# (file-viewer page residue: uploader m7mdal7aj, commit 6740cd3 "Update my_model/state_manager.py", verified, 20 kB)
# This module contains the StateManager class.
# The StateManager class is primarily designed to facilitate the Run Inference tool that allows users to load, run,
# and test the models.
import pandas as pd
import copy
import time
from PIL import Image
from typing import Tuple, Dict, Optional
import streamlit as st
from my_model.utilities.gen_utilities import free_gpu_resources
from my_model.KBVQA import KBVQA, prepare_kbvqa_model
class StateManager:
    """
    Manages the user interface and session state for the Streamlit-based Knowledge-Based Visual Question Answering
    (KBVQA) application.

    This class includes methods to initialize the session state, set up the UI widgets for model selection and
    settings, manage the loading and reloading of the KBVQA model, and handle the processing and analysis of images.
    It tracks changes to the application's state (via a `previous_state` snapshot kept in the session state) so that
    callers can tell whether the selected configuration still matches the loaded model. It also provides methods to
    display the current model settings and the complete application state within the Streamlit interface.

    The StateManager class is primarily designed to facilitate the Run Inference tool that allows users to load, run,
    and test the models.

    Attributes:
        col1 (streamlit.columns): The first column in the Streamlit layout.
        col2 (streamlit.columns): The second column in the Streamlit layout.
        col3 (streamlit.columns): The third column in the Streamlit layout.
    """

    def __init__(self) -> None:
        """
        Initializes the StateManager instance, setting up the Streamlit columns for the user interface.
        """
        # Three columns: two narrow side columns around a wide middle one.
        self.col1, self.col2, self.col3 = st.columns([0.2, 0.6, 0.2])

    def initialize_state(self) -> None:
        """
        Initializes the Streamlit session state with default values for every key this class relies on.

        Keys that already exist are left untouched, so calling this on every script rerun is safe.

        Returns:
            None
        """
        # A fresh dict is built on every call, so the mutable defaults are never shared between sessions.
        static_defaults = {
            'previous_state': {'method': None, 'detection_model': None, 'confidence_level': None},
            'images_data': {},
            'kbvqa': None,
            'button_label': "Load Model",
            'loading_in_progress': False,
            'load_button_clicked': False,
            'force_reload_button_clicked': False,
            'time_taken_to_load_model': None,
        }
        for key, default in static_defaults.items():
            if key not in st.session_state:
                st.session_state[key] = default

        # These two are computed from the state above, so they must come after it and are only evaluated when
        # the key is missing. Note that they store a one-off snapshot of the property value, not a live view.
        if 'settings_changed' not in st.session_state:
            st.session_state['settings_changed'] = self.settings_changed
        if 'model_loaded' not in st.session_state:
            st.session_state['model_loaded'] = self.is_model_loaded

    def set_up_widgets(self) -> None:
        """
        Sets up user interface widgets for selecting models and settings, and conditionally displays the model
        settings table.

        The widgets write their values to the session state under the keys 'method', 'detection_model' and
        'confidence_level'.

        Returns:
            None
        """
        self.col1.selectbox("Choose a model:",
                            ["13b-Fine-Tuned Model", "7b-Fine-Tuned Model", "Vision-Language Embeddings Alignment"],
                            index=1, key='method', disabled=self.is_widget_disabled)
        # The selection is read back through the session state (widget key), so the return value is not needed.
        self.col1.selectbox("Choose a model for objects detection:", ["yolov5", "detic"], index=1,
                            key='detection_model', disabled=self.is_widget_disabled)

        # The default confidence threshold depends on the selected detector.
        default_confidence = 0.2 if st.session_state.detection_model == "yolov5" else 0.4
        self.set_slider_value(text="Select minimum detection confidence level", min_value=0.1, max_value=0.9,
                              value=default_confidence, step=0.05, slider_key_name='confidence_level', col=self.col1)

        # Conditional display of model settings
        show_model_settings = self.col3.checkbox("Show Model Settings", True, disabled=self.is_widget_disabled)
        if show_model_settings:
            self._render_model_settings()

    def set_slider_value(self, text: str, min_value: float, max_value: float, value: float, step: float,
                         slider_key_name: str, col=None) -> float:
        """
        Creates a slider widget with the specified parameters, optionally placing it in a specific column.

        Args:
            text (str): Text to display next to the slider.
            min_value (float): Minimum value for the slider.
            max_value (float): Maximum value for the slider.
            value (float): Initial value for the slider.
            step (float): Step size for the slider.
            slider_key_name (str): Unique key for the slider.
            col (streamlit.columns.Column, optional): Column to place the slider in. Defaults to None (displayed in
                the main area).

        Returns:
            float: Whatever the underlying `slider` call returns (the slider's current value).
        """
        # Both branches previously duplicated the call; the main-area branch also misspelled the property name.
        target = st if col is None else col
        return target.slider(text, min_value, max_value, value, step, key=slider_key_name,
                             disabled=self.is_widget_disabled)

    @property
    def is_widget_disabled(self) -> bool:
        """
        Checks if widgets should be disabled based on the 'loading_in_progress' state.

        Returns:
            bool: True if widgets should be disabled, False otherwise.
        """
        return st.session_state['loading_in_progress']

    def disable_widgets(self) -> None:
        """
        Disables widgets by setting the 'loading_in_progress' state to True.

        Returns:
            None
        """
        st.session_state['loading_in_progress'] = True

    @property
    def settings_changed(self) -> bool:
        """
        Checks if any model settings have changed compared to the previous state.

        This is a property wrapper around `has_state_changed`, so the confidence level is not considered.

        Returns:
            bool: True if any setting has changed, False otherwise.
        """
        return self.has_state_changed()

    @property
    def confidance_change(self) -> bool:
        """
        Checks if the confidence level setting has changed compared to the previous state.

        The (misspelled) property name is kept as-is because it is part of the public interface.

        Returns:
            bool: True if the confidence level has changed, False otherwise.
        """
        return st.session_state["confidence_level"] != st.session_state["previous_state"]["confidence_level"]

    def update_prev_state(self) -> None:
        """
        Updates the 'previous_state' in the session state with the current state values.

        Every key tracked in 'previous_state' is overwritten with the value stored under the same key in the
        session state.

        Returns:
            None
        """
        previous = st.session_state['previous_state']
        for key in previous:
            previous[key] = st.session_state[key]

    def load_model(self) -> None:
        """
        Loads the KBVQA model based on the chosen method and settings.

        - Frees GPU resources before loading.
        - Calls `prepare_kbvqa_model` to create the model and stores it in the session state.
        - Sets the detection confidence level on the model object.
        - Updates previous state with current settings for change detection.
        - Marks the model as loaded and updates the button label to "Reload Model".

        Any exception is reported in the UI through `st.error` instead of being raised.

        Returns:
            None
        """
        try:
            free_gpu_resources()
            st.session_state['kbvqa'] = prepare_kbvqa_model()
            st.session_state['kbvqa'].detection_confidence = st.session_state.confidence_level
            # Update the previous state with current session state values
            self.update_prev_state()
            st.session_state['model_loaded'] = True
            st.session_state['button_label'] = "Reload Model"
            # Called twice in the original code; kept as-is.
            free_gpu_resources()
            free_gpu_resources()
        except Exception as e:
            st.error(f"Error loading model: {e}")

    def force_reload_model(self) -> None:
        """
        Forces a reload of all models, freeing up GPU resources around the reload.

        - Deletes the current KBVQA model from the session state (via `delete_model`).
        - Calls `prepare_kbvqa_model` with `force_reload=True` to reload the model.
        - Updates the detection confidence level on the model object.
        - Updates previous state with current settings and marks the model as loaded.

        Any exception is reported in the UI through `st.error` instead of being raised.

        Returns:
            None
        """
        try:
            self.delete_model()
            free_gpu_resources()
            st.session_state['kbvqa'] = prepare_kbvqa_model(force_reload=True)
            st.session_state['kbvqa'].detection_confidence = st.session_state.confidence_level
            # Update the previous state with current session state values
            self.update_prev_state()
            st.session_state['model_loaded'] = True
            free_gpu_resources()
        except Exception as e:
            st.error(f"Error reloading model: {e}")
            free_gpu_resources()

    def delete_model(self) -> None:
        """
        Deletes the current model from the session state (if one is loaded) and calls `free_gpu_resources`.

        Deletion is best-effort: a failure to remove the session entry is ignored, and GPU resources are freed
        either way.

        Returns:
            None
        """
        free_gpu_resources()
        if not self.is_model_loaded:
            return
        try:
            del st.session_state['kbvqa']
        except Exception:
            # Best-effort removal. Unlike a bare `except:`, this no longer swallows KeyboardInterrupt/SystemExit.
            pass
        finally:
            free_gpu_resources()
            free_gpu_resources()

    def has_state_changed(self) -> bool:
        """
        Compares current session state with the previous state to identify changes.

        The 'confidence_level' key is skipped because it is tracked separately (see `confidance_change`). Keys
        that are missing from the session state are treated as unchanged.

        Returns:
            bool: True if any change is found, False otherwise.
        """
        previous = st.session_state['previous_state']
        return any(
            key in st.session_state and st.session_state[key] != old_value
            for key, old_value in previous.items()
            if key != 'confidence_level'
        )

    def get_model(self) -> "Optional[KBVQA]":
        """
        Retrieves the KBVQA model from the session state.

        Returns:
            KBVQA: The loaded KBVQA model, or None if not loaded.
        """
        return st.session_state.get('kbvqa', None)

    @property
    def is_model_loaded(self) -> bool:
        """
        Checks if the KBVQA model is loaded in the session state.

        The model counts as loaded only when a model object is stored, it reports `all_models_loaded`, and the
        currently selected method is the one recorded in 'previous_state'.

        Returns:
            bool: True if the model is loaded, False otherwise.
        """
        state = st.session_state
        if 'kbvqa' not in state or state['kbvqa'] is None:
            return False
        if not state['kbvqa'].all_models_loaded:
            return False
        previous_method = state['previous_state']['method']
        return previous_method is not None and state['method'] == previous_method

    def reload_detection_model(self) -> None:
        """
        Reloads only the detection model of the KBVQA model with updated settings.

        - Frees GPU resources before reloading.
        - Checks if the model is already loaded; if it is not, nothing is reloaded.
        - Calls `prepare_kbvqa_model` with `only_reload_detection_model=True`.
        - Updates detection confidence level on the model object.
        - Displays a success message in the first column.
        - Updates previous state with current settings and sets the button label to "Reload Model".

        Any exception is reported in the UI through `st.error` instead of being raised.

        Returns:
            None
        """
        try:
            free_gpu_resources()
            if self.is_model_loaded:
                prepare_kbvqa_model(only_reload_detection_model=True)
                st.session_state['kbvqa'].detection_confidence = st.session_state.confidence_level
                self.col1.success("Model reloaded with updated settings and ready for inference.")
                # This was previously a bare attribute reference (no call), so the snapshot was never refreshed.
                self.update_prev_state()
                st.session_state['button_label'] = "Reload Model"
            free_gpu_resources()
        except Exception as e:
            st.error(f"Error reloading detection model: {e}")

    def process_new_image(self, image_key: str, image) -> None:
        """
        Processes a new uploaded image by creating an entry in the `images_data` dictionary in the application session
        state. An existing entry for the same key is left untouched.

        This dictionary stores information about each processed image, including:
        - `image`: The original image data.
        - `caption`: Generated caption for the image.
        - `detected_objects_str`: String representation of detected objects.
        - `qa_history`: List of questions and answers related to the image.
        - `analysis_done`: Flag indicating if analysis is complete.

        Args:
            image_key (str): Unique key for the image.
            image (obj): The uploaded image data.

        Returns:
            None
        """
        if image_key not in st.session_state['images_data']:
            st.session_state['images_data'][image_key] = {
                'image': image,
                'caption': '',
                'detected_objects_str': '',
                'qa_history': [],
                'analysis_done': False,
            }

    def analyze_image(self, image) -> Tuple[str, str, object]:
        """
        Analyzes the image using the KBVQA model stored in the session state.

        - Creates a deep copy of the image so the model calls do not receive the caller's object.
        - Calls the model's `get_caption` and `detect_objects` methods on the copy.
        - Frees GPU resources before and after the analysis.

        Args:
            image (obj): The image data to analyze.

        Returns:
            tuple: A tuple containing the generated caption, detected objects string, and image with bounding boxes.
        """
        free_gpu_resources()
        free_gpu_resources()
        img = copy.deepcopy(image)
        model = st.session_state['kbvqa']
        caption = model.get_caption(img)
        image_with_boxes, detected_objects_str = model.detect_objects(img)
        free_gpu_resources()
        return caption, detected_objects_str, image_with_boxes

    def add_to_qa_history(self, image_key: str, question: str, answer: str, prompt_length: int) -> None:
        """
        Adds a question-answer pair to the QA history of a specific image, to be used as a history tracker.
        Unknown image keys are ignored.

        Args:
            image_key (str): Unique key for the image.
            question (str): The question asked about the image.
            answer (str): The answer generated by the KBVQA model.
            prompt_length (int): The length of the prompt used for generating the answer.

        Returns:
            None
        """
        images_data = st.session_state['images_data']
        if image_key in images_data:
            images_data[image_key]['qa_history'].append((question, answer, prompt_length))

    def get_images_data(self) -> Dict:
        """
        Returns the dictionary containing processed image data from the session state.

        Returns:
            dict: The dictionary storing information about processed images.
        """
        return st.session_state['images_data']

    def update_image_data(self, image_key: str, caption: str, detected_objects_str: str, analysis_done: bool) -> None:
        """
        Updates the information stored for a specific image in the `images_data` dictionary in the application session
        state. Unknown image keys are ignored.

        Args:
            image_key (str): Unique key for the image.
            caption (str): The generated caption for the image.
            detected_objects_str (str): String representation of detected objects.
            analysis_done (bool): Flag indicating if analysis of the image is complete.

        Returns:
            None
        """
        images_data = st.session_state['images_data']
        if image_key in images_data:
            images_data[image_key].update({
                'caption': caption,
                'detected_objects_str': detected_objects_str,
                'analysis_done': analysis_done,
            })

    def resize_image(self, image_input, new_width: Optional[int] = None,
                     new_height: Optional[int] = None) -> "Image.Image":
        """
        Resizes an image. If only one of new_width / new_height is provided, the other dimension is computed to
        maintain the aspect ratio. If both are provided, the image is resized to exactly those dimensions.

        Args:
            image_input (str or PIL.Image.Image): A file path to open, or the image to resize.
            new_width (int, optional): The target width of the image.
            new_height (int, optional): The target height of the image.

        Returns:
            PIL.Image.Image: The resized image (a new image object).

        Raises:
            ValueError: If image_input is neither a file path nor a PIL image, or if both new_width and new_height
                are None.
        """
        # No defensive copy is needed: `resize` returns a new image and leaves its source untouched.
        if isinstance(image_input, str):
            # Open the image from a file path
            image = Image.open(image_input)
        elif isinstance(image_input, Image.Image):
            # Use the image directly if it's already a PIL Image object
            image = image_input
        else:
            raise ValueError("image_input must be a file path or a PIL Image object")

        if new_width is None and new_height is None:
            raise ValueError("At least one of new_width or new_height must be provided")

        original_width, original_height = image.size
        if new_height is None:
            # Derive the height from the width; clamp so extreme ratios never truncate to a 0-pixel side.
            ratio = new_width / original_width
            new_height = max(1, int(original_height * ratio))
        elif new_width is None:
            # Derive the width from the height; same clamp as above.
            ratio = new_height / original_height
            new_width = max(1, int(original_width * ratio))

        return image.resize((new_width, new_height))

    def display_message(self, message: str, message_type: str) -> None:
        """
        Displays a message in the Streamlit interface based on the specified message type.

        Args:
            message (str): The message to display.
            message_type (str): The type of message ('warning', 'text', 'success', 'write', or 'error').
                Any other value shows a "Message type unknown" error instead of the message.

        Returns:
            None
        """
        renderers = {
            "warning": st.warning,
            "text": st.text,
            "success": st.success,
            "write": st.write,
            # 'error' is a documented type but previously fell through to the "unknown" branch.
            "error": st.error,
        }
        renderer = renderers.get(message_type)
        if renderer is None:
            st.error("Message type unknown")
        else:
            renderer(message)

    @property
    def display_model_settings(self) -> None:
        """
        Displays a table of current model settings in the third column.

        Kept as a property for backward compatibility: merely accessing the attribute renders the table.

        Returns:
            None
        """
        return self._render_model_settings()

    def _render_model_settings(self) -> None:
        """
        Writes a heading and a two-column (Setting / Value) table of selected session-state entries to the third
        column. Values are shown through `str()`.

        Returns:
            Whatever `col3.write` returns for the table.
        """
        shown_keys = {
            'confidence_level', 'detection_model', 'method', 'kbvqa', 'previous_state', 'settings_changed',
            'loading_in_progress', 'model_loaded', 'time_taken_to_load_model', 'images_data',
        }
        self.col3.write("##### Current Model Settings:")
        data = [{'Setting': key, 'Value': str(value)}
                for key, value in st.session_state.items() if key in shown_keys]
        df = pd.DataFrame(data).reset_index(drop=True)
        return self.col3.write(df)

    def display_session_state(self, col) -> None:
        """
        Displays a table of the complete application state in the specified column.

        Args:
            col (streamlit.columns.Column): The Streamlit column to display the session state.

        Returns:
            None
        """
        col.write("Current Model:")
        data = [{'Key': key, 'Value': str(value)} for key, value in st.session_state.items()]
        df = pd.DataFrame(data).reset_index(drop=True)
        col.write(df)