m7mdal7aj commited on
Commit
b491c60
1 Parent(s): cc2adc7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +19 -6
app.py CHANGED
@@ -8,6 +8,7 @@ import scipy
8
  import copy
9
  from PIL import Image
10
  import torch.nn as nn
 
11
  from my_model.object_detection import detect_and_draw_objects
12
  from my_model.captioner.image_captioning import get_caption
13
  from my_model.gen_utilities import free_gpu_resources
@@ -29,13 +30,14 @@ sample_images = ["Files/sample1.jpg", "Files/sample2.jpg", "Files/sample3.jpg",
29
 
30
 
31
  def analyze_image(image, model):
32
- st.text("Cool image, let me analyze it..")
33
- caption = model.get_caption(image)
34
- image_with_boxes, detected_objects_str = model.detect_objects(image)
 
35
  st.text("I am ready, let's talk!")
36
  free_gpu_resources()
37
 
38
- return caption, detected_objects_str
39
 
40
 
41
  def image_qa_app(kbvqa):
@@ -61,8 +63,9 @@ def image_qa_app(kbvqa):
61
  for image_key, image_data in st.session_state['images_data'].items():
62
  st.image(image_data['image'], caption=f'Uploaded Image: {image_key[-11:]}', use_column_width=True)
63
  if not image_data['analysis_done']:
 
64
  if st.button('Analyze Image', key=f'analyze_{image_key}'):
65
- caption, detected_objects_str = analyze_image(image_data['image'], kbvqa)
66
  image_data['caption'] = caption
67
  image_data['detected_objects_str'] = detected_objects_str
68
  image_data['analysis_done'] = True
@@ -98,7 +101,7 @@ def process_new_image(image_key, image, kbvqa):
98
 
99
  def run_inference():
100
  st.title("Run Inference")
101
- st.write("Please note that this is not a general purpose model, it is specifically trained on OK-VQA dataset and is designed to give short answers to the given questions.")
102
 
103
  method = st.selectbox(
104
  "Choose a method:",
@@ -165,6 +168,16 @@ def run_inference():
165
  else:
166
  st.write('Model is not ready for inference yet')
167
 
 
 
 
 
 
 
 
 
 
 
168
 
169
 
170
  # Main function
 
8
  import copy
9
  from PIL import Image
10
  import torch.nn as nn
11
+ import pandas as pd
12
  from my_model.object_detection import detect_and_draw_objects
13
  from my_model.captioner.image_captioning import get_caption
14
  from my_model.gen_utilities import free_gpu_resources
 
30
 
31
 
32
  def analyze_image(image, model):
33
+
34
+ img = copy.deepcopy(image) # we dont wanna apply changes to the original image
35
+ caption = model.get_caption(img)
36
+ image_with_boxes, detected_objects_str = model.detect_objects(img)
37
  st.text("I am ready, let's talk!")
38
  free_gpu_resources()
39
 
40
+ return caption, detected_objects_str, image_with_boxes
41
 
42
 
43
  def image_qa_app(kbvqa):
 
63
  for image_key, image_data in st.session_state['images_data'].items():
64
  st.image(image_data['image'], caption=f'Uploaded Image: {image_key[-11:]}', use_column_width=True)
65
  if not image_data['analysis_done']:
66
+ st.text("Cool image, please click 'Analyze Image'..")
67
  if st.button('Analyze Image', key=f'analyze_{image_key}'):
68
+ caption, detected_objects_str, image_with_boxes = analyze_image(image_data['image'], kbvqa) # we can use the image_with_boxes later if we want to show it.
69
  image_data['caption'] = caption
70
  image_data['detected_objects_str'] = detected_objects_str
71
  image_data['analysis_done'] = True
 
101
 
102
  def run_inference():
103
  st.title("Run Inference")
104
+ st.write("Please note that this is not a general purpose model, it is specifically trained on OK-VQA dataset and is designed to give direct and short answers to the given questions.")
105
 
106
  method = st.selectbox(
107
  "Choose a method:",
 
168
  else:
169
  st.write('Model is not ready for inference yet')
170
 
171
+ # Display model settings
172
+ if 'kbvqa' in st.session_state and st.session_state['kbvqa'] is not None:
173
+ model_settings = {
174
+ 'Detection Model': st.session_state['model_settings']['detection_model'],
175
+ 'Confidence Level': st.session_state['model_settings']['confidence_level']
176
+ }
177
+ st.write("### Current Model Settings:")
178
+ st.table(pd.DataFrame(model_settings, index=[0]))
179
+
180
+
181
 
182
 
183
  # Main function