Update my_model/state_manager.py
Browse files- my_model/state_manager.py +140 -5
my_model/state_manager.py
CHANGED
@@ -29,6 +29,9 @@ class StateManager:
|
|
29 |
|
30 |
|
31 |
def set_up_widgets(self):
|
|
|
|
|
|
|
32 |
|
33 |
self.col1.selectbox("Choose a method:", ["Fine-Tuned Model", "In-Context Learning (n-shots)"], index=0, key='method')
|
34 |
detection_model = self.col1.selectbox("Choose a model for objects detection:", ["yolov5", "detic"], index=1, key='detection_model')
|
@@ -45,6 +48,19 @@ class StateManager:
|
|
45 |
|
46 |
|
47 |
def set_slider_value(self, text, min_value, max_value, value, step, slider_key_name, col=None):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
48 |
if col is None:
|
49 |
return st.slider(text, min_value, max_value, value, step, key=slider_key_name)
|
50 |
else:
|
@@ -53,10 +69,21 @@ class StateManager:
|
|
53 |
|
54 |
@property
|
55 |
def settings_changed(self):
|
|
|
|
|
|
|
|
|
|
|
|
|
56 |
return self.has_state_changed()
|
57 |
|
58 |
|
59 |
def display_model_settings(self):
|
|
|
|
|
|
|
|
|
|
|
60 |
self.col3.write("##### Current Model Settings:")
|
61 |
data = [{'Key': key, 'Value': str(value)} for key, value in st.session_state.items() if key in ["confidence_level", 'detection_model', 'method', 'kbvqa', 'previous_state', 'settings_changed', ]]
|
62 |
df = pd.DataFrame(data)
|
@@ -65,6 +92,10 @@ class StateManager:
|
|
65 |
|
66 |
|
67 |
def display_session_state(self):
|
|
|
|
|
|
|
|
|
68 |
st.write("Current Model:")
|
69 |
data = [{'Key': key, 'Value': str(value)} for key, value in st.session_state.items()]
|
70 |
df = pd.DataFrame(data)
|
@@ -72,7 +103,16 @@ class StateManager:
|
|
72 |
|
73 |
|
74 |
def load_model(self):
|
75 |
-
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
76 |
try:
|
77 |
free_gpu_resources()
|
78 |
st.session_state['kbvqa'] = prepare_kbvqa_model()
|
@@ -91,6 +131,12 @@ class StateManager:
|
|
91 |
|
92 |
# Function to check if any session state values have changed
|
93 |
def has_state_changed(self):
|
|
|
|
|
|
|
|
|
|
|
|
|
94 |
for key in st.session_state['previous_state']:
|
95 |
if st.session_state[key] != st.session_state['previous_state'][key]:
|
96 |
return True # Found a change
|
@@ -98,15 +144,35 @@ class StateManager:
|
|
98 |
|
99 |
|
100 |
def get_model(self):
|
101 |
-
"""
|
|
|
|
|
|
|
|
|
102 |
return st.session_state.get('kbvqa', None)
|
103 |
|
104 |
|
105 |
def is_model_loaded(self):
|
|
|
|
|
|
|
|
|
|
|
|
|
106 |
return 'kbvqa' in st.session_state and st.session_state['kbvqa'] is not None
|
107 |
|
108 |
|
109 |
def reload_detection_model(self):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
110 |
try:
|
111 |
free_gpu_resources()
|
112 |
if self.is_model_loaded():
|
@@ -119,6 +185,22 @@ class StateManager:
|
|
119 |
|
120 |
|
121 |
def process_new_image(self, image_key, image, kbvqa):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
122 |
if image_key not in st.session_state['images_data']:
|
123 |
st.session_state['images_data'][image_key] = {
|
124 |
'image': image,
|
@@ -130,6 +212,21 @@ class StateManager:
|
|
130 |
|
131 |
|
132 |
def analyze_image(self, image, kbvqa):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
133 |
img = copy.deepcopy(image)
|
134 |
st.text("Analyzing the image .. ")
|
135 |
caption = kbvqa.get_caption(img)
|
@@ -138,22 +235,60 @@ class StateManager:
|
|
138 |
|
139 |
|
140 |
def add_to_qa_history(self, image_key, question, answer):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
141 |
if image_key in st.session_state['images_data']:
|
142 |
st.session_state['images_data'][image_key]['qa_history'].append((question, answer))
|
143 |
|
144 |
|
145 |
def get_images_data(self):
|
|
|
|
|
|
|
|
|
|
|
|
|
146 |
return st.session_state['images_data']
|
147 |
|
148 |
@staticmethod
|
149 |
def resize_image(image_path, new_width, new_height):
|
150 |
-
"""
|
151 |
-
image
|
152 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
153 |
return resized_image
|
154 |
|
155 |
|
156 |
def update_image_data(self, image_key, caption, detected_objects_str, analysis_done):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
157 |
if image_key in st.session_state['images_data']:
|
158 |
st.session_state['images_data'][image_key].update({
|
159 |
'caption': caption,
|
|
|
29 |
|
30 |
|
31 |
def set_up_widgets(self):
|
32 |
+
"""
|
33 |
+
Sets up user interface widgets for selecting models, settings, and displaying model settings conditionally.
|
34 |
+
"""
|
35 |
|
36 |
self.col1.selectbox("Choose a method:", ["Fine-Tuned Model", "In-Context Learning (n-shots)"], index=0, key='method')
|
37 |
detection_model = self.col1.selectbox("Choose a model for objects detection:", ["yolov5", "detic"], index=1, key='detection_model')
|
|
|
48 |
|
49 |
|
50 |
def set_slider_value(self, text, min_value, max_value, value, step, slider_key_name, col=None):
|
51 |
+
"""
|
52 |
+
Creates a slider widget with the specified parameters, optionally placing it in a specific column.
|
53 |
+
|
54 |
+
Args:
|
55 |
+
text (str): Text to display next to the slider.
|
56 |
+
min_value (float): Minimum value for the slider.
|
57 |
+
max_value (float): Maximum value for the slider.
|
58 |
+
value (float): Initial value for the slider.
|
59 |
+
step (float): Step size for the slider.
|
60 |
+
slider_key_name (str): Unique key for the slider.
|
61 |
+
col (streamlit.columns.Column, optional): Column to place the slider in. Defaults to None (displayed in main area).
|
62 |
+
"""
|
63 |
+
|
64 |
if col is None:
|
65 |
return st.slider(text, min_value, max_value, value, step, key=slider_key_name)
|
66 |
else:
|
|
|
69 |
|
70 |
@property
|
71 |
def settings_changed(self):
|
72 |
+
"""
|
73 |
+
Checks if any model settings have changed compared to the previous state.
|
74 |
+
|
75 |
+
Returns:
|
76 |
+
bool: True if any setting has changed, False otherwise.
|
77 |
+
"""
|
78 |
return self.has_state_changed()
|
79 |
|
80 |
|
81 |
def display_model_settings(self):
|
82 |
+
"""
|
83 |
+
Displays a table of current model settings in the third column.
|
84 |
+
|
85 |
+
Uses formatted HTML to style the table for better readability.
|
86 |
+
"""
|
87 |
self.col3.write("##### Current Model Settings:")
|
88 |
data = [{'Key': key, 'Value': str(value)} for key, value in st.session_state.items() if key in ["confidence_level", 'detection_model', 'method', 'kbvqa', 'previous_state', 'settings_changed', ]]
|
89 |
df = pd.DataFrame(data)
|
|
|
92 |
|
93 |
|
94 |
def display_session_state(self):
|
95 |
+
"""
|
96 |
+
Displays a table of the complete application state..
|
97 |
+
"""
|
98 |
+
|
99 |
st.write("Current Model:")
|
100 |
data = [{'Key': key, 'Value': str(value)} for key, value in st.session_state.items()]
|
101 |
df = pd.DataFrame(data)
|
|
|
103 |
|
104 |
|
105 |
def load_model(self):
|
106 |
+
"""
|
107 |
+
Loads the KBVQA model based on the chosen method and settings.
|
108 |
+
|
109 |
+
- Frees GPU resources before loading.
|
110 |
+
- Calls `prepare_kbvqa_model` to create the model.
|
111 |
+
- Sets the detection confidence level on the model object.
|
112 |
+
- Updates previous state with current settings for change detection.
|
113 |
+
- Updates the button label to "Reload Model".
|
114 |
+
"""
|
115 |
+
|
116 |
try:
|
117 |
free_gpu_resources()
|
118 |
st.session_state['kbvqa'] = prepare_kbvqa_model()
|
|
|
131 |
|
132 |
# Function to check if any session state values have changed
|
133 |
def has_state_changed(self):
|
134 |
+
"""
|
135 |
+
Compares current session state with the previous state to identify changes.
|
136 |
+
|
137 |
+
Returns:
|
138 |
+
bool: True if any change is found, False otherwise.
|
139 |
+
"""
|
140 |
for key in st.session_state['previous_state']:
|
141 |
if st.session_state[key] != st.session_state['previous_state'][key]:
|
142 |
return True # Found a change
|
|
|
144 |
|
145 |
|
146 |
def get_model(self):
|
147 |
+
"""
|
148 |
+
Retrieve the KBVQA model from the session state.
|
149 |
+
|
150 |
+
Returns: KBVQA object: The loaded KBVQA model, or None if not loaded.
|
151 |
+
"""
|
152 |
return st.session_state.get('kbvqa', None)
|
153 |
|
154 |
|
155 |
def is_model_loaded(self):
|
156 |
+
"""
|
157 |
+
Checks if the KBVQA model is loaded in the session state.
|
158 |
+
|
159 |
+
Returns:
|
160 |
+
bool: True if the model is loaded, False otherwise.
|
161 |
+
"""
|
162 |
return 'kbvqa' in st.session_state and st.session_state['kbvqa'] is not None
|
163 |
|
164 |
|
165 |
def reload_detection_model(self):
|
166 |
+
"""
|
167 |
+
Reloads only the detection model of the KBVQA model with updated settings.
|
168 |
+
|
169 |
+
- Frees GPU resources before reloading.
|
170 |
+
- Checks if the model is already loaded.
|
171 |
+
- Calls `prepare_kbvqa_model` with `only_reload_detection_model=True`.
|
172 |
+
- Updates detection confidence level on the model object.
|
173 |
+
- Displays a success message if model is reloaded successfully.
|
174 |
+
"""
|
175 |
+
|
176 |
try:
|
177 |
free_gpu_resources()
|
178 |
if self.is_model_loaded():
|
|
|
185 |
|
186 |
|
187 |
def process_new_image(self, image_key, image, kbvqa):
|
188 |
+
"""
|
189 |
+
Processes a new uploaded image by creating an entry in the `images_data` dictionary in the application session state.
|
190 |
+
|
191 |
+
This dictionary stores information about each processed image, including:
|
192 |
+
- `image`: The original image data.
|
193 |
+
- `caption`: Generated caption for the image.
|
194 |
+
- `detected_objects_str`: String representation of detected objects.
|
195 |
+
- `qa_history`: List of questions and answers related to the image.
|
196 |
+
- `analysis_done`: Flag indicating if analysis is complete.
|
197 |
+
|
198 |
+
Args:
|
199 |
+
image_key (str): Unique key for the image.
|
200 |
+
image (obj): The uploaded image data.
|
201 |
+
kbvqa (KBVQA object): The loaded KBVQA model.
|
202 |
+
"""
|
203 |
+
|
204 |
if image_key not in st.session_state['images_data']:
|
205 |
st.session_state['images_data'][image_key] = {
|
206 |
'image': image,
|
|
|
212 |
|
213 |
|
214 |
def analyze_image(self, image, kbvqa):
|
215 |
+
"""
|
216 |
+
Analyzes the image using the KBVQA model.
|
217 |
+
|
218 |
+
- Creates a copy of the image to avoid modifying the original.
|
219 |
+
- Displays a "Analyzing the image .." message.
|
220 |
+
- Calls KBVQA methods to generate a caption and detect objects.
|
221 |
+
- Returns the generated caption, detected objects string, and image with bounding boxes.
|
222 |
+
|
223 |
+
Args:
|
224 |
+
image (obj): The image data to analyze.
|
225 |
+
kbvqa (KBVQA object): The loaded KBVQA model.
|
226 |
+
|
227 |
+
Returns:
|
228 |
+
tuple: A tuple containing the generated caption, detected objects string, and image with bounding boxes.
|
229 |
+
"""
|
230 |
img = copy.deepcopy(image)
|
231 |
st.text("Analyzing the image .. ")
|
232 |
caption = kbvqa.get_caption(img)
|
|
|
235 |
|
236 |
|
237 |
def add_to_qa_history(self, image_key, question, answer):
|
238 |
+
"""
|
239 |
+
Adds a question-answer pair to the QA history of a specific image, to be used as hitory tracker.
|
240 |
+
|
241 |
+
Args:
|
242 |
+
image_key (str): Unique key for the image.
|
243 |
+
question (str): The question asked about the image.
|
244 |
+
answer (str): The answer generated by the KBVQA model.
|
245 |
+
"""
|
246 |
if image_key in st.session_state['images_data']:
|
247 |
st.session_state['images_data'][image_key]['qa_history'].append((question, answer))
|
248 |
|
249 |
|
250 |
def get_images_data(self):
|
251 |
+
"""
|
252 |
+
Returns the dictionary containing processed image data from the session state.
|
253 |
+
|
254 |
+
Returns:
|
255 |
+
dict: The dictionary storing information about processed images.
|
256 |
+
"""
|
257 |
return st.session_state['images_data']
|
258 |
|
259 |
@staticmethod
|
260 |
def resize_image(image_path, new_width, new_height):
|
261 |
+
"""
|
262 |
+
Resizes an image from the specified to the given dimensions.
|
263 |
+
|
264 |
+
Args:
|
265 |
+
image_path (str): Path to the image file.
|
266 |
+
new_width (int): Desired width for the resized image.
|
267 |
+
new_height (int): Desired height for the resized image.
|
268 |
+
|
269 |
+
Returns:
|
270 |
+
Image: The resized image object.
|
271 |
+
"""
|
272 |
+
|
273 |
+
if isinstance(image_path, str):
|
274 |
+
# Open the image from a file path
|
275 |
+
image = Image.open(image_path)
|
276 |
+
elif hasattr(image_path, 'read'):
|
277 |
+
resized_image = image.resize((new_width, new_height))
|
278 |
+
|
279 |
return resized_image
|
280 |
|
281 |
|
282 |
def update_image_data(self, image_key, caption, detected_objects_str, analysis_done):
|
283 |
+
"""
|
284 |
+
Updates the information stored for a specific image in the `images_data` dictionary in the application session state.
|
285 |
+
|
286 |
+
Args:
|
287 |
+
image_key (str): Unique key for the image.
|
288 |
+
caption (str): The generated caption for the image.
|
289 |
+
detected_objects_str (str): String representation of detected objects.
|
290 |
+
analysis_done (bool): Flag indicating if analysis of the image is complete.
|
291 |
+
"""
|
292 |
if image_key in st.session_state['images_data']:
|
293 |
st.session_state['images_data'][image_key].update({
|
294 |
'caption': caption,
|