m7mdal7aj commited on
Commit
cf8c147
1 Parent(s): 29f316e

Update my_model/state_manager.py

Browse files
Files changed (1) hide show
  1. my_model/state_manager.py +140 -5
my_model/state_manager.py CHANGED
@@ -29,6 +29,9 @@ class StateManager:
29
 
30
 
31
  def set_up_widgets(self):
 
 
 
32
 
33
  self.col1.selectbox("Choose a method:", ["Fine-Tuned Model", "In-Context Learning (n-shots)"], index=0, key='method')
34
  detection_model = self.col1.selectbox("Choose a model for objects detection:", ["yolov5", "detic"], index=1, key='detection_model')
@@ -45,6 +48,19 @@ class StateManager:
45
 
46
 
47
  def set_slider_value(self, text, min_value, max_value, value, step, slider_key_name, col=None):
 
 
 
 
 
 
 
 
 
 
 
 
 
48
  if col is None:
49
  return st.slider(text, min_value, max_value, value, step, key=slider_key_name)
50
  else:
@@ -53,10 +69,21 @@ class StateManager:
53
 
54
  @property
55
  def settings_changed(self):
 
 
 
 
 
 
56
  return self.has_state_changed()
57
 
58
 
59
  def display_model_settings(self):
 
 
 
 
 
60
  self.col3.write("##### Current Model Settings:")
61
  data = [{'Key': key, 'Value': str(value)} for key, value in st.session_state.items() if key in ["confidence_level", 'detection_model', 'method', 'kbvqa', 'previous_state', 'settings_changed', ]]
62
  df = pd.DataFrame(data)
@@ -65,6 +92,10 @@ class StateManager:
65
 
66
 
67
  def display_session_state(self):
 
 
 
 
68
  st.write("Current Model:")
69
  data = [{'Key': key, 'Value': str(value)} for key, value in st.session_state.items()]
70
  df = pd.DataFrame(data)
@@ -72,7 +103,16 @@ class StateManager:
72
 
73
 
74
  def load_model(self):
75
- """Load the KBVQA model with specified settings."""
 
 
 
 
 
 
 
 
 
76
  try:
77
  free_gpu_resources()
78
  st.session_state['kbvqa'] = prepare_kbvqa_model()
@@ -91,6 +131,12 @@ class StateManager:
91
 
92
  # Function to check if any session state values have changed
93
  def has_state_changed(self):
 
 
 
 
 
 
94
  for key in st.session_state['previous_state']:
95
  if st.session_state[key] != st.session_state['previous_state'][key]:
96
  return True # Found a change
@@ -98,15 +144,35 @@ class StateManager:
98
 
99
 
100
  def get_model(self):
101
- """Retrieve the KBVQA model from the session state."""
 
 
 
 
102
  return st.session_state.get('kbvqa', None)
103
 
104
 
105
  def is_model_loaded(self):
 
 
 
 
 
 
106
  return 'kbvqa' in st.session_state and st.session_state['kbvqa'] is not None
107
 
108
 
109
  def reload_detection_model(self):
 
 
 
 
 
 
 
 
 
 
110
  try:
111
  free_gpu_resources()
112
  if self.is_model_loaded():
@@ -119,6 +185,22 @@ class StateManager:
119
 
120
 
121
  def process_new_image(self, image_key, image, kbvqa):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
122
  if image_key not in st.session_state['images_data']:
123
  st.session_state['images_data'][image_key] = {
124
  'image': image,
@@ -130,6 +212,21 @@ class StateManager:
130
 
131
 
132
  def analyze_image(self, image, kbvqa):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
133
  img = copy.deepcopy(image)
134
  st.text("Analyzing the image .. ")
135
  caption = kbvqa.get_caption(img)
@@ -138,22 +235,60 @@ class StateManager:
138
 
139
 
140
  def add_to_qa_history(self, image_key, question, answer):
 
 
 
 
 
 
 
 
141
  if image_key in st.session_state['images_data']:
142
  st.session_state['images_data'][image_key]['qa_history'].append((question, answer))
143
 
144
 
145
  def get_images_data(self):
 
 
 
 
 
 
146
  return st.session_state['images_data']
147
 
148
  @staticmethod
149
  def resize_image(image_path, new_width, new_height):
150
- """Resize an image to the specified width and height."""
151
- image = Image.open(image_path)
152
- resized_image = image.resize((new_width, new_height))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
153
  return resized_image
154
 
155
 
156
  def update_image_data(self, image_key, caption, detected_objects_str, analysis_done):
 
 
 
 
 
 
 
 
 
157
  if image_key in st.session_state['images_data']:
158
  st.session_state['images_data'][image_key].update({
159
  'caption': caption,
 
29
 
30
 
31
  def set_up_widgets(self):
32
+ """
33
+ Sets up user interface widgets for selecting models, settings, and displaying model settings conditionally.
34
+ """
35
 
36
  self.col1.selectbox("Choose a method:", ["Fine-Tuned Model", "In-Context Learning (n-shots)"], index=0, key='method')
37
  detection_model = self.col1.selectbox("Choose a model for objects detection:", ["yolov5", "detic"], index=1, key='detection_model')
 
48
 
49
 
50
  def set_slider_value(self, text, min_value, max_value, value, step, slider_key_name, col=None):
51
+ """
52
+ Creates a slider widget with the specified parameters, optionally placing it in a specific column.
53
+
54
+ Args:
55
+ text (str): Text to display next to the slider.
56
+ min_value (float): Minimum value for the slider.
57
+ max_value (float): Maximum value for the slider.
58
+ value (float): Initial value for the slider.
59
+ step (float): Step size for the slider.
60
+ slider_key_name (str): Unique key for the slider.
61
+ col (streamlit.columns.Column, optional): Column to place the slider in. Defaults to None (displayed in main area).
62
+ """
63
+
64
  if col is None:
65
  return st.slider(text, min_value, max_value, value, step, key=slider_key_name)
66
  else:
 
69
 
70
  @property
71
  def settings_changed(self):
72
+ """
73
+ Checks if any model settings have changed compared to the previous state.
74
+
75
+ Returns:
76
+ bool: True if any setting has changed, False otherwise.
77
+ """
78
  return self.has_state_changed()
79
 
80
 
81
  def display_model_settings(self):
82
+ """
83
+ Displays a table of current model settings in the third column.
84
+
85
+ Uses formatted HTML to style the table for better readability.
86
+ """
87
  self.col3.write("##### Current Model Settings:")
88
  data = [{'Key': key, 'Value': str(value)} for key, value in st.session_state.items() if key in ["confidence_level", 'detection_model', 'method', 'kbvqa', 'previous_state', 'settings_changed', ]]
89
  df = pd.DataFrame(data)
 
92
 
93
 
94
  def display_session_state(self):
95
+ """
96
+ Displays a table of the complete application state..
97
+ """
98
+
99
  st.write("Current Model:")
100
  data = [{'Key': key, 'Value': str(value)} for key, value in st.session_state.items()]
101
  df = pd.DataFrame(data)
 
103
 
104
 
105
  def load_model(self):
106
+ """
107
+ Loads the KBVQA model based on the chosen method and settings.
108
+
109
+ - Frees GPU resources before loading.
110
+ - Calls `prepare_kbvqa_model` to create the model.
111
+ - Sets the detection confidence level on the model object.
112
+ - Updates previous state with current settings for change detection.
113
+ - Updates the button label to "Reload Model".
114
+ """
115
+
116
  try:
117
  free_gpu_resources()
118
  st.session_state['kbvqa'] = prepare_kbvqa_model()
 
131
 
132
  # Function to check if any session state values have changed
133
  def has_state_changed(self):
134
+ """
135
+ Compares current session state with the previous state to identify changes.
136
+
137
+ Returns:
138
+ bool: True if any change is found, False otherwise.
139
+ """
140
  for key in st.session_state['previous_state']:
141
  if st.session_state[key] != st.session_state['previous_state'][key]:
142
  return True # Found a change
 
144
 
145
 
146
  def get_model(self):
147
+ """
148
+ Retrieve the KBVQA model from the session state.
149
+
150
+ Returns: KBVQA object: The loaded KBVQA model, or None if not loaded.
151
+ """
152
  return st.session_state.get('kbvqa', None)
153
 
154
 
155
  def is_model_loaded(self):
156
+ """
157
+ Checks if the KBVQA model is loaded in the session state.
158
+
159
+ Returns:
160
+ bool: True if the model is loaded, False otherwise.
161
+ """
162
  return 'kbvqa' in st.session_state and st.session_state['kbvqa'] is not None
163
 
164
 
165
  def reload_detection_model(self):
166
+ """
167
+ Reloads only the detection model of the KBVQA model with updated settings.
168
+
169
+ - Frees GPU resources before reloading.
170
+ - Checks if the model is already loaded.
171
+ - Calls `prepare_kbvqa_model` with `only_reload_detection_model=True`.
172
+ - Updates detection confidence level on the model object.
173
+ - Displays a success message if model is reloaded successfully.
174
+ """
175
+
176
  try:
177
  free_gpu_resources()
178
  if self.is_model_loaded():
 
185
 
186
 
187
  def process_new_image(self, image_key, image, kbvqa):
188
+ """
189
+ Processes a new uploaded image by creating an entry in the `images_data` dictionary in the application session state.
190
+
191
+ This dictionary stores information about each processed image, including:
192
+ - `image`: The original image data.
193
+ - `caption`: Generated caption for the image.
194
+ - `detected_objects_str`: String representation of detected objects.
195
+ - `qa_history`: List of questions and answers related to the image.
196
+ - `analysis_done`: Flag indicating if analysis is complete.
197
+
198
+ Args:
199
+ image_key (str): Unique key for the image.
200
+ image (obj): The uploaded image data.
201
+ kbvqa (KBVQA object): The loaded KBVQA model.
202
+ """
203
+
204
  if image_key not in st.session_state['images_data']:
205
  st.session_state['images_data'][image_key] = {
206
  'image': image,
 
212
 
213
 
214
  def analyze_image(self, image, kbvqa):
215
+ """
216
+ Analyzes the image using the KBVQA model.
217
+
218
+ - Creates a copy of the image to avoid modifying the original.
219
+ - Displays a "Analyzing the image .." message.
220
+ - Calls KBVQA methods to generate a caption and detect objects.
221
+ - Returns the generated caption, detected objects string, and image with bounding boxes.
222
+
223
+ Args:
224
+ image (obj): The image data to analyze.
225
+ kbvqa (KBVQA object): The loaded KBVQA model.
226
+
227
+ Returns:
228
+ tuple: A tuple containing the generated caption, detected objects string, and image with bounding boxes.
229
+ """
230
  img = copy.deepcopy(image)
231
  st.text("Analyzing the image .. ")
232
  caption = kbvqa.get_caption(img)
 
235
 
236
 
237
  def add_to_qa_history(self, image_key, question, answer):
238
+ """
239
+ Adds a question-answer pair to the QA history of a specific image, to be used as hitory tracker.
240
+
241
+ Args:
242
+ image_key (str): Unique key for the image.
243
+ question (str): The question asked about the image.
244
+ answer (str): The answer generated by the KBVQA model.
245
+ """
246
  if image_key in st.session_state['images_data']:
247
  st.session_state['images_data'][image_key]['qa_history'].append((question, answer))
248
 
249
 
250
  def get_images_data(self):
251
+ """
252
+ Returns the dictionary containing processed image data from the session state.
253
+
254
+ Returns:
255
+ dict: The dictionary storing information about processed images.
256
+ """
257
  return st.session_state['images_data']
258
 
259
  @staticmethod
260
  def resize_image(image_path, new_width, new_height):
261
+ """
262
+ Resizes an image from the specified to the given dimensions.
263
+
264
+ Args:
265
+ image_path (str): Path to the image file.
266
+ new_width (int): Desired width for the resized image.
267
+ new_height (int): Desired height for the resized image.
268
+
269
+ Returns:
270
+ Image: The resized image object.
271
+ """
272
+
273
+ if isinstance(image_path, str):
274
+ # Open the image from a file path
275
+ image = Image.open(image_path)
276
+ elif hasattr(image_path, 'read'):
277
+ resized_image = image.resize((new_width, new_height))
278
+
279
  return resized_image
280
 
281
 
282
  def update_image_data(self, image_key, caption, detected_objects_str, analysis_done):
283
+ """
284
+ Updates the information stored for a specific image in the `images_data` dictionary in the application session state.
285
+
286
+ Args:
287
+ image_key (str): Unique key for the image.
288
+ caption (str): The generated caption for the image.
289
+ detected_objects_str (str): String representation of detected objects.
290
+ analysis_done (bool): Flag indicating if analysis of the image is complete.
291
+ """
292
  if image_key in st.session_state['images_data']:
293
  st.session_state['images_data'][image_key].update({
294
  'caption': caption,