m7mdal7aj commited on
Commit
4a6ec58
1 Parent(s): 4d96ac5

Update my_model/tabs/run_inference.py

Browse files
Files changed (1) hide show
  1. my_model/tabs/run_inference.py +19 -25
my_model/tabs/run_inference.py CHANGED
@@ -61,7 +61,7 @@ class InferenceRunner(StateManager):
61
 
62
 
63
  # Display sample images as clickable thumbnails
64
- st.write("D")
65
  self.col1.write("Choose from sample images:")
66
  cols = self.col1.columns(len(config.SAMPLE_IMAGES))
67
  for idx, sample_image_path in enumerate(config.SAMPLE_IMAGES):
@@ -81,40 +81,37 @@ class InferenceRunner(StateManager):
81
  self.display_session_state()
82
  with self.col2:
83
  for image_key, image_data in self.get_images_data().items():
84
- st.write("E")
85
  with st.container():
86
  nested_col21, nested_col22 = st.columns([0.65, 0.35])
87
  image_for_display = self.resize_image(image_data['image'], 600)
88
  nested_col21.image(image_for_display, caption=f'Uploaded Image: {image_key[-11:]}')
89
- st.write(image_data['analysis_done'] , self.settings_changed , self.confidance_change)
90
  if not image_data['analysis_done'] or self.settings_changed or self.confidance_change: # if not done analysis before or even done but settings changed, then we need to analyze again
91
- st.write("F")
92
  nested_col22.text("Please click 'Analyze Image'..")
93
  free_gpu_resources()
94
  with nested_col22:
95
- st.write("G")
96
  analyze_button_key = f'analyze_{image_key}_{st.session_state.detection_model}_{st.session_state.confidence_level}' # unique key for each click
97
- st.write(analyze_button_key)
98
- st.write("H")
99
  if st.button('Analyze Image', key=analyze_button_key, on_click=self.disable_widgets, disabled=self.is_widget_disabled):
100
- st.text("AAAAAAAAAAAAAAAAAAAAA")
101
  caption, detected_objects_str, image_with_boxes = self.analyze_image(image_data['image'])
102
- st.text("BBBBBBBBBBBBBBBBBBBB")
103
- st.write(detected_objects_str)
104
  self.update_image_data(image_key, caption, detected_objects_str, True)
105
  st.session_state['loading_in_progress'] = False
106
  free_gpu_resources()
107
- st.write("II")
108
 
109
  # Initialize qa_history for each image
110
  qa_history = image_data.get('qa_history', [])
111
- st.write("J")
112
  if image_data['analysis_done']:
113
- st.write("K")
114
  free_gpu_resources()
115
  if self.confidance_change:
116
- st.write("L")
117
- nested_col22.warning("Confidence level changed, please click analyze again.")
118
 
119
  st.session_state['loading_in_progress'] = False
120
  sample_questions = config.SAMPLE_QUESTIONS.get(image_key, [])
@@ -130,15 +127,12 @@ class InferenceRunner(StateManager):
130
  # Use the selected sample question or the custom question
131
  question = custom_question if selected_question == "Custom question..." else selected_question
132
 
133
- # if not question:
134
- # nested_col22.warning("Please select or enter a question.")
135
- # else:
136
  if question in [q for q, _, _ in qa_history] and not self.settings_changed and not self.confidance_change:
137
  nested_col22.warning("This question has already been answered.")
138
  else:
139
  if nested_col22.button('Get Answer', key=f'answer_{image_key}', disabled=self.is_widget_disabled):
140
  free_gpu_resources()
141
- st.write("M")
142
  answer, prompt_length = self.answer_question(image_data['caption'], image_data['detected_objects_str'], question)
143
  st.session_state['loading_in_progress'] = False
144
  self.add_to_qa_history(image_key, question, answer, prompt_length)
@@ -147,7 +141,7 @@ class InferenceRunner(StateManager):
147
  for num, (q, a, p) in enumerate(qa_history):
148
  nested_col22.text(f"Q{num+1}: {q}\nA{num+1}: {a}\nPrompt Length: {p}\n")
149
  free_gpu_resources()
150
- st.write("N")
151
 
152
 
153
 
@@ -159,20 +153,20 @@ class InferenceRunner(StateManager):
159
  """
160
 
161
  self.set_up_widgets()
162
- st.write("A")
163
  load_fine_tuned_model = False
164
  fine_tuned_model_already_loaded = False
165
  reload_detection_model = False
166
  force_reload_full_model = False
167
 
168
- #st.session_state['settings_changed'] = self.has_state_changed()
169
  if self.is_model_loaded and self.settings_changed:
170
  self.col1.warning("Model settings have changed, please reload the model, this will take a second .. ")
171
  self.update_prev_state()
172
- st.write("B")
173
 
174
- st.session_state.button_label = "Reload Model" if self.is_model_loaded and self.settings_changed else "Load Model"
175
- st.write("C")
176
  with self.col1:
177
  if st.session_state.method == "Fine-Tuned Model":
178
  with st.container():
 
61
 
62
 
63
  # Display sample images as clickable thumbnails
64
+
65
  self.col1.write("Choose from sample images:")
66
  cols = self.col1.columns(len(config.SAMPLE_IMAGES))
67
  for idx, sample_image_path in enumerate(config.SAMPLE_IMAGES):
 
81
  self.display_session_state()
82
  with self.col2:
83
  for image_key, image_data in self.get_images_data().items():
84
+
85
  with st.container():
86
  nested_col21, nested_col22 = st.columns([0.65, 0.35])
87
  image_for_display = self.resize_image(image_data['image'], 600)
88
  nested_col21.image(image_for_display, caption=f'Uploaded Image: {image_key[-11:]}')
89
+ nested_col21.write(image_data['analysis_done'] , self.settings_changed , self.confidance_change)
90
  if not image_data['analysis_done'] or self.settings_changed or self.confidance_change: # if not done analysis before or even done but settings changed, then we need to analyze again
91
+
92
  nested_col22.text("Please click 'Analyze Image'..")
93
  free_gpu_resources()
94
  with nested_col22:
95
+
96
  analyze_button_key = f'analyze_{image_key}_{st.session_state.detection_model}_{st.session_state.confidence_level}' # unique key for each click
97
+
 
98
  if st.button('Analyze Image', key=analyze_button_key, on_click=self.disable_widgets, disabled=self.is_widget_disabled):
 
99
  caption, detected_objects_str, image_with_boxes = self.analyze_image(image_data['image'])
100
+
 
101
  self.update_image_data(image_key, caption, detected_objects_str, True)
102
  st.session_state['loading_in_progress'] = False
103
  free_gpu_resources()
104
+
105
 
106
  # Initialize qa_history for each image
107
  qa_history = image_data.get('qa_history', [])
108
+
109
  if image_data['analysis_done']:
110
+
111
  free_gpu_resources()
112
  if self.confidance_change:
113
+
114
+ nested_col22.warning("If you change the Confidence level, please click analyze again.")
115
 
116
  st.session_state['loading_in_progress'] = False
117
  sample_questions = config.SAMPLE_QUESTIONS.get(image_key, [])
 
127
  # Use the selected sample question or the custom question
128
  question = custom_question if selected_question == "Custom question..." else selected_question
129
 
 
 
 
130
  if question in [q for q, _, _ in qa_history] and not self.settings_changed and not self.confidance_change:
131
  nested_col22.warning("This question has already been answered.")
132
  else:
133
  if nested_col22.button('Get Answer', key=f'answer_{image_key}', disabled=self.is_widget_disabled):
134
  free_gpu_resources()
135
+
136
  answer, prompt_length = self.answer_question(image_data['caption'], image_data['detected_objects_str'], question)
137
  st.session_state['loading_in_progress'] = False
138
  self.add_to_qa_history(image_key, question, answer, prompt_length)
 
141
  for num, (q, a, p) in enumerate(qa_history):
142
  nested_col22.text(f"Q{num+1}: {q}\nA{num+1}: {a}\nPrompt Length: {p}\n")
143
  free_gpu_resources()
144
+
145
 
146
 
147
 
 
153
  """
154
 
155
  self.set_up_widgets()
156
+
157
  load_fine_tuned_model = False
158
  fine_tuned_model_already_loaded = False
159
  reload_detection_model = False
160
  force_reload_full_model = False
161
 
162
+
163
  if self.is_model_loaded and self.settings_changed:
164
  self.col1.warning("Model settings have changed, please reload the model, this will take a second .. ")
165
  self.update_prev_state()
166
+
167
 
168
+ st.session_state.button_label = "Reload Model" if self.is_model_loaded and st.session_state.kvbqa.detection_model != st.session_state['detection_model'] else "Load Model"
169
+
170
  with self.col1:
171
  if st.session_state.method == "Fine-Tuned Model":
172
  with st.container():