m7mdal7aj committed on
Commit
7337830
1 Parent(s): 61382de

Update my_model/tabs/run_inference.py

Browse files
Files changed (1) hide show
  1. my_model/tabs/run_inference.py +73 -66
my_model/tabs/run_inference.py CHANGED
@@ -17,14 +17,17 @@ from my_model.state_manager import StateManager
17
  from my_model.config import inference_config as config
18
 
19
 
 
20
  class InferenceRunner(StateManager):
21
  """
22
- Manages the user interface and interactions for running inference using the Streamlit-based Knowledge-Based Visual Question Answering (KBVQA) application.
23
-
24
- This class handles image uploads, displays sample images, and facilitates the question-answering process using the KBVQA model.
 
 
25
  Inherits from the StateManager class.
26
  """
27
-
28
  def __init__(self) -> None:
29
  """
30
  Initializes the InferenceRunner instance, setting up the necessary state.
@@ -32,7 +35,6 @@ class InferenceRunner(StateManager):
32
 
33
  super().__init__()
34
 
35
-
36
  def answer_question(self, caption: str, detected_objects_str: str, question: str) -> Tuple[str, int]:
37
  """
38
  Generates an answer to a user's question based on the image's caption and detected objects.
@@ -45,14 +47,13 @@ class InferenceRunner(StateManager):
45
  Returns:
46
  Tuple[str, int]: A tuple containing the answer to the question and the prompt length.
47
  """
48
-
49
  free_gpu_resources()
50
  answer = st.session_state.kbvqa.generate_answer(question, caption, detected_objects_str)
51
- prompt_length = st.session_state.kbvqa.current_prompt_length
52
  free_gpu_resources()
53
  return answer, prompt_length
54
 
55
-
56
  def display_sample_images(self) -> None:
57
  """
58
  Displays sample images as clickable thumbnails for the user to select.
@@ -60,7 +61,7 @@ class InferenceRunner(StateManager):
60
  Returns:
61
  None
62
  """
63
-
64
  self.col1.write("Choose from sample images:")
65
  cols = self.col1.columns(len(config.SAMPLE_IMAGES))
66
  for idx, sample_image_path in enumerate(config.SAMPLE_IMAGES):
@@ -68,10 +69,9 @@ class InferenceRunner(StateManager):
68
  image = Image.open(sample_image_path)
69
  image_for_display = self.resize_image(sample_image_path, 80, 80)
70
  st.image(image_for_display)
71
- if st.button(f'Select Sample Image {idx + 1}', key=f'sample_{idx+1}'):
72
  self.process_new_image(sample_image_path, image)
73
 
74
-
75
  def handle_image_upload(self) -> None:
76
  """
77
  Provides an image uploader widget for the user to upload their own images.
@@ -79,13 +79,13 @@ class InferenceRunner(StateManager):
79
  Returns:
80
  None
81
  """
82
-
83
  uploaded_image = self.col1.file_uploader("Or upload an Image", type=["png", "jpg", "jpeg"])
84
  if uploaded_image is not None:
85
  self.process_new_image(uploaded_image.name, Image.open(uploaded_image))
86
-
87
 
88
- def display_image_and_analysis(self, image_key: str, image_data: Dict, nested_col21: DeltaGenerator, nested_col22: DeltaGenerator) -> None:
 
89
  """
90
  Displays the uploaded or selected image and provides an option to analyze the image.
91
 
@@ -94,15 +94,14 @@ class InferenceRunner(StateManager):
94
  image_data (Dict): Data associated with the image.
95
  nested_col21 (DeltaGenerator): Column for displaying the image.
96
  nested_col22 (DeltaGenerator): Column for displaying the analysis button.
97
-
98
  Returns:
99
  None
100
  """
101
-
102
  image_for_display = self.resize_image(image_data['image'], 600)
103
  nested_col21.image(image_for_display, caption=f'Uploaded Image: {image_key[-11:]}')
104
  self.handle_analysis_button(image_key, image_data, nested_col22)
105
-
106
 
107
  def handle_analysis_button(self, image_key: str, image_data: Dict, nested_col22: DeltaGenerator) -> None:
108
  """
@@ -112,22 +111,23 @@ class InferenceRunner(StateManager):
112
  image_key (str): Unique key identifying the image.
113
  image_data (Dict): Data associated with the image.
114
  nested_col22 (DeltaGenerator): Column for displaying the analysis button.
115
-
116
  Returns:
117
  None
118
  """
119
-
120
  if not image_data['analysis_done'] or self.settings_changed or self.confidance_change:
121
  nested_col22.text("Please click 'Analyze Image'..")
122
- analyze_button_key = f'analyze_{image_key}_{st.session_state.detection_model}_{st.session_state.confidence_level}'
 
123
  with nested_col22:
124
- if st.button('Analyze Image', key=analyze_button_key, on_click=self.disable_widgets, disabled=self.is_widget_disabled):
 
125
  with st.spinner('Analyzing the image...'):
126
  caption, detected_objects_str, image_with_boxes = self.analyze_image(image_data['image'])
127
  self.update_image_data(image_key, caption, detected_objects_str, True)
128
  st.session_state['loading_in_progress'] = False
129
 
130
-
131
  def handle_question_answering(self, image_key: str, image_data: Dict, nested_col22: DeltaGenerator) -> None:
132
  """
133
  Manages the question-answering interface for each image.
@@ -136,19 +136,19 @@ class InferenceRunner(StateManager):
136
  image_key (str): Unique key identifying the image.
137
  image_data (Dict): Data associated with the image.
138
  nested_col22 (DeltaGenerator): Column for displaying the question-answering interface.
139
-
140
  Returns:
141
  None
142
  """
143
-
144
  if image_data['analysis_done']:
145
  self.display_question_answering_interface(image_key, image_data, nested_col22)
146
 
147
  if self.settings_changed or self.confidance_change:
148
  nested_col22.warning("Confidence level changed, please click 'Analyze Image' each time you change it.")
149
 
150
-
151
- def display_question_answering_interface(self, image_key: str, image_data: Dict, nested_col22: DeltaGenerator) -> None:
152
  """
153
  Displays the interface for question answering, including sample questions and a custom question input.
154
 
@@ -156,31 +156,33 @@ class InferenceRunner(StateManager):
156
  image_key (str): Unique key identifying the image.
157
  image_data (Dict): Data associated with the image.
158
  nested_col22 (DeltaGenerator): The column where the interface will be displayed.
159
-
160
  Returns:
161
  None
162
  """
163
-
164
  sample_questions = config.SAMPLE_QUESTIONS.get(image_key, [])
165
- selected_question = nested_col22.selectbox("Select a sample question or type your own:", ["Custom question..."] + sample_questions, key=f'sample_question_{image_key}')
166
-
 
 
167
  # Display custom question input only if "Custom question..." is selected
168
  question = selected_question
169
  if selected_question == "Custom question...":
170
  custom_question = nested_col22.text_input("Or ask your own question:", key=f'custom_question_{image_key}')
171
  question = custom_question
172
-
173
  self.process_question(image_key, question, image_data, nested_col22)
174
-
175
  qa_history = image_data.get('qa_history', [])
176
  for num, (q, a, p) in enumerate(qa_history):
177
- nested_col22.text(f"Q{num+1}: {q}\nA{num+1}: {a}\nPrompt Length: {p}\n")
178
-
179
 
180
  def process_question(self, image_key: str, question: str, image_data: Dict, nested_col22: DeltaGenerator) -> None:
181
  """
182
  Processes the user's question, generates an answer, and updates the question-answer history.
183
- This method checks if the question is new or if settings have changed, and if so, generates an answer using the KBVQA model.
 
184
  It then updates the question-answer history for the image.
185
 
186
  Args:
@@ -188,73 +190,76 @@ class InferenceRunner(StateManager):
188
  question (str): The question asked by the user.
189
  image_data (Dict): Data associated with the image.
190
  nested_col22 (DeltaGenerator): The column where the answer will be displayed.
191
-
192
  Returns:
193
  None
194
  """
195
-
196
  qa_history = image_data.get('qa_history', [])
197
- if question and (question not in [q for q, _, _ in qa_history] or self.settings_changed or self.confidance_change):
 
198
  if nested_col22.button('Get Answer', key=f'answer_{image_key}', disabled=self.is_widget_disabled):
199
- answer, prompt_length = self.answer_question(image_data['caption'], image_data['detected_objects_str'], question)
 
200
  self.add_to_qa_history(image_key, question, answer, prompt_length)
201
-
202
 
203
  def image_qa_app(self) -> None:
204
  """
205
  Main application interface for image-based question answering.
206
 
207
- This method orchestrates the display of sample images, handles image uploads, and facilitates the question-answering process.
208
- It iterates through each image in the session state, displaying the image and providing interfaces for image analysis and question answering.
209
-
 
 
210
  Returns:
211
  None
212
  """
213
-
214
  self.display_sample_images()
215
  self.handle_image_upload()
216
- #self.display_session_state(self.col1)
217
  with self.col2:
218
  for image_key, image_data in self.get_images_data().items():
219
  with st.container():
220
  nested_col21, nested_col22 = st.columns([0.65, 0.35])
221
  self.display_image_and_analysis(image_key, image_data, nested_col21, nested_col22)
222
  self.handle_question_answering(image_key, image_data, nested_col22)
223
-
224
-
225
  def run_inference(self) -> None:
226
  """
227
- Sets up widgets and manages the inference process, including model loading and reloading, based on user interactions.
 
228
 
229
  This method orchestrates the overall flow of the inference process.
230
-
231
  Returns:
232
  None
233
  """
234
 
235
  self.set_up_widgets() # Inherited from the StateManager class
236
-
237
  load_fine_tuned_model = False
238
  fine_tuned_model_already_loaded = False
239
  reload_detection_model = False
240
  force_reload_full_model = False
241
-
242
  if self.is_model_loaded and self.settings_changed:
243
  self.col1.warning("Model settings have changed, please reload the model, this will take a second .. ")
244
  self.update_prev_state()
245
  st.session_state.button_label = (
246
- "Reload Model" if (self.is_model_loaded and
247
- st.session_state.kbvqa.detection_model != st.session_state['detection_model'] and
248
- self.settings_changed())
249
- else "Load Model"
250
- )
251
 
252
-
253
  with self.col1:
254
  if st.session_state.method == "7b-Fine-Tuned Model" or st.session_state.method == "13b-Fine-Tuned Model":
255
  with st.container():
256
  nested_col11, nested_col12 = st.columns([0.5, 0.5])
257
- if nested_col11.button(st.session_state.button_label, on_click=self.disable_widgets, disabled=self.is_widget_disabled):
 
258
  if st.session_state.button_label == "Load Model":
259
  if self.is_model_loaded:
260
  free_gpu_resources()
@@ -263,13 +268,14 @@ class InferenceRunner(StateManager):
263
  load_fine_tuned_model = True
264
  else:
265
  reload_detection_model = True
266
- if nested_col12.button("Force Reload", on_click=self.disable_widgets, disabled=self.is_widget_disabled):
 
267
  force_reload_full_model = True
268
  if load_fine_tuned_model:
269
- t1=time.time()
270
  free_gpu_resources()
271
  self.load_model()
272
- st.session_state['time_taken_to_load_model'] = int(time.time()-t1)
273
  st.session_state['loading_in_progress'] = False
274
  elif fine_tuned_model_already_loaded:
275
  free_gpu_resources()
@@ -281,16 +287,17 @@ class InferenceRunner(StateManager):
281
  st.session_state['loading_in_progress'] = False
282
  elif force_reload_full_model:
283
  free_gpu_resources()
284
- t1=time.time()
285
  self.force_reload_model()
286
- st.session_state['time_taken_to_load_model'] = int(time.time()-t1)
287
  st.session_state['loading_in_progress'] = False
288
  st.session_state['model_loaded'] = True
289
  elif st.session_state.method == "Vision-Language Embeddings Alignment":
290
- self.col1.warning(f'Model using {st.session_state.method} is desgined but requires large scale data and multiple high-end GPUs, implementation will be explored in the future.')
 
 
291
  if self.is_model_loaded:
292
  free_gpu_resources()
293
  st.session_state['loading_in_progress'] = False
294
- self.image_qa_app() # this is the main Q/A Application
295
-
296
 
 
17
  from my_model.config import inference_config as config
18
 
19
 
20
+
21
  class InferenceRunner(StateManager):
22
  """
23
+ Manages the user interface and interactions for running inference using the Streamlit-based Knowledge-Based Visual
24
+ Question Answering (KBVQA) application.
25
+
26
+ This class handles image uploads, displays sample images, and facilitates the question-answering process using the
27
+ KBVQA model.
28
  Inherits from the StateManager class.
29
  """
30
+
31
  def __init__(self) -> None:
32
  """
33
  Initializes the InferenceRunner instance, setting up the necessary state.
 
35
 
36
  super().__init__()
37
 
 
38
  def answer_question(self, caption: str, detected_objects_str: str, question: str) -> Tuple[str, int]:
39
  """
40
  Generates an answer to a user's question based on the image's caption and detected objects.
 
47
  Returns:
48
  Tuple[str, int]: A tuple containing the answer to the question and the prompt length.
49
  """
50
+
51
  free_gpu_resources()
52
  answer = st.session_state.kbvqa.generate_answer(question, caption, detected_objects_str)
53
+ prompt_length = st.session_state.kbvqa.current_prompt_length
54
  free_gpu_resources()
55
  return answer, prompt_length
56
 
 
57
  def display_sample_images(self) -> None:
58
  """
59
  Displays sample images as clickable thumbnails for the user to select.
 
61
  Returns:
62
  None
63
  """
64
+
65
  self.col1.write("Choose from sample images:")
66
  cols = self.col1.columns(len(config.SAMPLE_IMAGES))
67
  for idx, sample_image_path in enumerate(config.SAMPLE_IMAGES):
 
69
  image = Image.open(sample_image_path)
70
  image_for_display = self.resize_image(sample_image_path, 80, 80)
71
  st.image(image_for_display)
72
+ if st.button(f'Select Sample Image {idx + 1}', key=f'sample_{idx + 1}'):
73
  self.process_new_image(sample_image_path, image)
74
 
 
75
  def handle_image_upload(self) -> None:
76
  """
77
  Provides an image uploader widget for the user to upload their own images.
 
79
  Returns:
80
  None
81
  """
82
+
83
  uploaded_image = self.col1.file_uploader("Or upload an Image", type=["png", "jpg", "jpeg"])
84
  if uploaded_image is not None:
85
  self.process_new_image(uploaded_image.name, Image.open(uploaded_image))
 
86
 
87
+ def display_image_and_analysis(self, image_key: str, image_data: Dict, nested_col21: DeltaGenerator,
88
+ nested_col22: DeltaGenerator) -> None:
89
  """
90
  Displays the uploaded or selected image and provides an option to analyze the image.
91
 
 
94
  image_data (Dict): Data associated with the image.
95
  nested_col21 (DeltaGenerator): Column for displaying the image.
96
  nested_col22 (DeltaGenerator): Column for displaying the analysis button.
97
+
98
  Returns:
99
  None
100
  """
101
+
102
  image_for_display = self.resize_image(image_data['image'], 600)
103
  nested_col21.image(image_for_display, caption=f'Uploaded Image: {image_key[-11:]}')
104
  self.handle_analysis_button(image_key, image_data, nested_col22)
 
105
 
106
  def handle_analysis_button(self, image_key: str, image_data: Dict, nested_col22: DeltaGenerator) -> None:
107
  """
 
111
  image_key (str): Unique key identifying the image.
112
  image_data (Dict): Data associated with the image.
113
  nested_col22 (DeltaGenerator): Column for displaying the analysis button.
114
+
115
  Returns:
116
  None
117
  """
118
+
119
  if not image_data['analysis_done'] or self.settings_changed or self.confidance_change:
120
  nested_col22.text("Please click 'Analyze Image'..")
121
+ analyze_button_key = f'analyze_{image_key}_{st.session_state.detection_model}_' \
122
+ f'{st.session_state.confidence_level}'
123
  with nested_col22:
124
+ if st.button('Analyze Image', key=analyze_button_key, on_click=self.disable_widgets,
125
+ disabled=self.is_widget_disabled):
126
  with st.spinner('Analyzing the image...'):
127
  caption, detected_objects_str, image_with_boxes = self.analyze_image(image_data['image'])
128
  self.update_image_data(image_key, caption, detected_objects_str, True)
129
  st.session_state['loading_in_progress'] = False
130
 
 
131
  def handle_question_answering(self, image_key: str, image_data: Dict, nested_col22: DeltaGenerator) -> None:
132
  """
133
  Manages the question-answering interface for each image.
 
136
  image_key (str): Unique key identifying the image.
137
  image_data (Dict): Data associated with the image.
138
  nested_col22 (DeltaGenerator): Column for displaying the question-answering interface.
139
+
140
  Returns:
141
  None
142
  """
143
+
144
  if image_data['analysis_done']:
145
  self.display_question_answering_interface(image_key, image_data, nested_col22)
146
 
147
  if self.settings_changed or self.confidance_change:
148
  nested_col22.warning("Confidence level changed, please click 'Analyze Image' each time you change it.")
149
 
150
+ def display_question_answering_interface(self, image_key: str, image_data: Dict,
151
+ nested_col22: DeltaGenerator) -> None:
152
  """
153
  Displays the interface for question answering, including sample questions and a custom question input.
154
 
 
156
  image_key (str): Unique key identifying the image.
157
  image_data (Dict): Data associated with the image.
158
  nested_col22 (DeltaGenerator): The column where the interface will be displayed.
159
+
160
  Returns:
161
  None
162
  """
163
+
164
  sample_questions = config.SAMPLE_QUESTIONS.get(image_key, [])
165
+ selected_question = nested_col22.selectbox("Select a sample question or type your own:",
166
+ ["Custom question..."] + sample_questions,
167
+ key=f'sample_question_{image_key}')
168
+
169
  # Display custom question input only if "Custom question..." is selected
170
  question = selected_question
171
  if selected_question == "Custom question...":
172
  custom_question = nested_col22.text_input("Or ask your own question:", key=f'custom_question_{image_key}')
173
  question = custom_question
174
+
175
  self.process_question(image_key, question, image_data, nested_col22)
176
+
177
  qa_history = image_data.get('qa_history', [])
178
  for num, (q, a, p) in enumerate(qa_history):
179
+ nested_col22.text(f"Q{num + 1}: {q}\nA{num + 1}: {a}\nPrompt Length: {p}\n")
 
180
 
181
  def process_question(self, image_key: str, question: str, image_data: Dict, nested_col22: DeltaGenerator) -> None:
182
  """
183
  Processes the user's question, generates an answer, and updates the question-answer history.
184
+ This method checks if the question is new or if settings have changed, and if so, generates an answer using the
185
+ KBVQA model.
186
  It then updates the question-answer history for the image.
187
 
188
  Args:
 
190
  question (str): The question asked by the user.
191
  image_data (Dict): Data associated with the image.
192
  nested_col22 (DeltaGenerator): The column where the answer will be displayed.
193
+
194
  Returns:
195
  None
196
  """
197
+
198
  qa_history = image_data.get('qa_history', [])
199
+ if question and (
200
+ question not in [q for q, _, _ in qa_history] or self.settings_changed or self.confidance_change):
201
  if nested_col22.button('Get Answer', key=f'answer_{image_key}', disabled=self.is_widget_disabled):
202
+ answer, prompt_length = self.answer_question(image_data['caption'], image_data['detected_objects_str'],
203
+ question)
204
  self.add_to_qa_history(image_key, question, answer, prompt_length)
 
205
 
206
  def image_qa_app(self) -> None:
207
  """
208
  Main application interface for image-based question answering.
209
 
210
+ This method orchestrates the display of sample images, handles image uploads, and facilitates the
211
+ question-answering process.
212
+ It iterates through each image in the session state, displaying the image and providing interfaces for image
213
+ analysis and question answering.
214
+
215
  Returns:
216
  None
217
  """
218
+
219
  self.display_sample_images()
220
  self.handle_image_upload()
221
+ # self.display_session_state(self.col1)
222
  with self.col2:
223
  for image_key, image_data in self.get_images_data().items():
224
  with st.container():
225
  nested_col21, nested_col22 = st.columns([0.65, 0.35])
226
  self.display_image_and_analysis(image_key, image_data, nested_col21, nested_col22)
227
  self.handle_question_answering(image_key, image_data, nested_col22)
228
+
 
229
  def run_inference(self) -> None:
230
  """
231
+ Sets up widgets and manages the inference process, including model loading and reloading, based on user
232
+ interactions.
233
 
234
  This method orchestrates the overall flow of the inference process.
235
+
236
  Returns:
237
  None
238
  """
239
 
240
  self.set_up_widgets() # Inherited from the StateManager class
241
+
242
  load_fine_tuned_model = False
243
  fine_tuned_model_already_loaded = False
244
  reload_detection_model = False
245
  force_reload_full_model = False
246
+
247
  if self.is_model_loaded and self.settings_changed:
248
  self.col1.warning("Model settings have changed, please reload the model, this will take a second .. ")
249
  self.update_prev_state()
250
  st.session_state.button_label = (
251
+ "Reload Model" if (self.is_model_loaded and
252
+ st.session_state.kbvqa.detection_model != st.session_state['detection_model']) or
253
+ self.settings_changed())
254
+ else "Load Model"
255
+ )
256
 
 
257
  with self.col1:
258
  if st.session_state.method == "7b-Fine-Tuned Model" or st.session_state.method == "13b-Fine-Tuned Model":
259
  with st.container():
260
  nested_col11, nested_col12 = st.columns([0.5, 0.5])
261
+ if nested_col11.button(st.session_state.button_label, on_click=self.disable_widgets,
262
+ disabled=self.is_widget_disabled):
263
  if st.session_state.button_label == "Load Model":
264
  if self.is_model_loaded:
265
  free_gpu_resources()
 
268
  load_fine_tuned_model = True
269
  else:
270
  reload_detection_model = True
271
+ if nested_col12.button("Force Reload", on_click=self.disable_widgets,
272
+ disabled=self.is_widget_disabled):
273
  force_reload_full_model = True
274
  if load_fine_tuned_model:
275
+ t1 = time.time()
276
  free_gpu_resources()
277
  self.load_model()
278
+ st.session_state['time_taken_to_load_model'] = int(time.time() - t1)
279
  st.session_state['loading_in_progress'] = False
280
  elif fine_tuned_model_already_loaded:
281
  free_gpu_resources()
 
287
  st.session_state['loading_in_progress'] = False
288
  elif force_reload_full_model:
289
  free_gpu_resources()
290
+ t1 = time.time()
291
  self.force_reload_model()
292
+ st.session_state['time_taken_to_load_model'] = int(time.time() - t1)
293
  st.session_state['loading_in_progress'] = False
294
  st.session_state['model_loaded'] = True
295
  elif st.session_state.method == "Vision-Language Embeddings Alignment":
296
+ self.col1.warning(
297
+ f'Model using {st.session_state.method} is desgined but requires large scale data and multiple '
298
+ f'high-end GPUs, implementation will be explored in the future.')
299
  if self.is_model_loaded:
300
  free_gpu_resources()
301
  st.session_state['loading_in_progress'] = False
302
+ self.image_qa_app() # this is the main Q/A Application
 
303