akridge commited on
Commit
5a3e8f4
β€’
1 Parent(s): 5fa482e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +281 -281
app.py CHANGED
@@ -1,281 +1,281 @@
1
- # Import required modules
2
- import streamlit as st
3
- from ultralytics import YOLO
4
- from PIL import Image
5
- import os
6
- import json
7
- import logging
8
- import tempfile
9
- import pandas as pd
10
- import matplotlib.pyplot as plt
11
-
12
- st.set_page_config(
13
- page_title="Fish Detector",
14
- page_icon="🐟",
15
- layout="wide"
16
- )
17
- sample_images_folder = "./images/sample_images"
18
- logging.basicConfig(level=logging.INFO)
19
-
20
- # Model loading
21
- model_folder = "./models"
22
- st.sidebar.title("🐟 Fish or No Fish Detector")
23
- st.sidebar.markdown("""
24
- ### For more information:
25
- - Contact: Michael.Akridge@NOAA.gov
26
- - Visit the [GitHub repository](https://github.com/MichaelAkridge-NOAA/Fish-or-No-Fish-Detector/)
27
- """)
28
- # Display model links
29
- st.sidebar.markdown("### Model Links")
30
- st.sidebar.markdown("- [YOLO11 Fish Detector - Grayscale](https://huggingface.co/akridge/yolo11-fish-detector-grayscale)")
31
- st.sidebar.markdown("- [YOLO11 Segment Fish - Grayscale](https://huggingface.co/akridge/yolo11-segment-fish-grayscale)")
32
- model_name = st.sidebar.selectbox("Select a YOLO model", os.listdir(model_folder))
33
- model_path = os.path.join(model_folder, model_name)
34
- if not os.path.exists(model_path):
35
- st.error(f"Model file not found at {model_path}. Please check your setup.")
36
- st.stop()
37
- model = YOLO(model_path)
38
-
39
- # Sidebar configuration
40
- st.sidebar.header("Model Parameters")
41
- confidence = st.sidebar.slider("Detection Confidence Threshold", 0.0, 1.0, 0.35)
42
- final_confidence = st.sidebar.slider("Final Yes/No Confidence Threshold", 0.0, 1.0, 0.5)
43
-
44
- # Title and description
45
- st.title("🐟 Fish or No Fish Detector")
46
- st.write("""
47
- Is there a fish 🐟 or not? Upload one or more images to detect fish. Using a trained [Ultralytics YOLO11 Model](https://github.com/ultralytics/ultralytics) for its object detection capabilities.
48
-
49
- """)
50
-
51
- # Custom CSS for button and uploader alignment
52
- st.markdown("""
53
- <style>
54
- .custom-file-uploader {
55
- display: flex;
56
- align-items: center;
57
- margin-top: -10px; /* Adjust to move button closer */
58
- justify-content: flex-start;
59
- }
60
- .css-1cpxqw2 {
61
- flex-grow: 1; /* Let file uploader take remaining space */
62
- }
63
- .sample-button {
64
- font-size: 14px;
65
- padding: 8px;
66
- background-color: #007BFF;
67
- color: white;
68
- border: none;
69
- border-radius: 5px;
70
- cursor: pointer;
71
- margin-left: 10px;
72
- height: 38px; /* Ensure button matches uploader height */
73
- }
74
- .sample-button:hover {
75
- background-color: #0056b3;
76
- }
77
- </style>
78
- """, unsafe_allow_html=True)
79
-
80
- # Custom CSS for default button styling
81
- st.markdown("""
82
- <style>
83
- .stButton>button, .stDownloadButton>button {
84
- width: 100%;
85
- padding: 10px;
86
- border-radius: 5px;
87
- font-size: 18px;
88
- font-weight: bold;
89
- background-color: #007BFF;
90
- color: white;
91
- border: none;
92
- cursor: pointer;
93
- }
94
- .stButton>button:hover, .stDownloadButton>button:hover {
95
- background-color: #0056b3;
96
- }
97
- </style>
98
- """, unsafe_allow_html=True)
99
- # Load sample images function
100
- def load_sample_images():
101
- return [os.path.join(sample_images_folder, img) for img in os.listdir(sample_images_folder) if img.lower().endswith(('png', 'jpg', 'jpeg'))]
102
-
103
- # Prediction function
104
- def run(image_path):
105
- results = model.predict(image_path, conf=confidence)
106
- boxes = []
107
- fish_count = 0
108
- confidences = []
109
-
110
- for box in results[0].boxes:
111
- x1, y1, x2, y2 = box.xyxy[0].tolist()
112
- conf = box.conf[0].item()
113
- class_id = int(box.cls[0].item())
114
- class_label = model.names[class_id].lower() # Normalize to lowercase
115
-
116
- if class_label == "fish" and conf > confidence:
117
- fish_count += 1
118
- confidences.append(conf)
119
-
120
- boxes.append({"x1": x1, "y1": y1, "x2": x2, "y2": y2, "confidence": conf, "class_id": class_id, "class_label": class_label})
121
-
122
- return results[0].plot()[:, :, ::-1], {"fish_count": fish_count, "confidences": confidences}
123
-
124
- # Process images function with directory creation
125
- # Reusable function to handle multiple image uploads and display results
126
- def process_images(uploaded_files):
127
- all_detections = []
128
- result_images = []
129
- summary_data = []
130
- confidences = []
131
- temp_dir = tempfile.gettempdir()
132
-
133
- for uploaded_file in uploaded_files:
134
- if isinstance(uploaded_file, str): # Check if it's a sample image path
135
- image_path = uploaded_file
136
- image = Image.open(image_path)
137
- else:
138
- image = Image.open(uploaded_file)
139
- image_path = os.path.join(temp_dir, f"{uploaded_file.name}")
140
- image.save(image_path)
141
-
142
- st.write(f"Detecting in {os.path.basename(image_path)}...")
143
- with st.spinner('Running detection...'):
144
- result_image, detection_metadata = run(image_path)
145
-
146
- if result_image is not None:
147
- result_images.append((result_image, os.path.basename(image_path)))
148
- all_detections.append(detection_metadata)
149
-
150
- summary_data.append({
151
- "image_name": os.path.basename(image_path),
152
- "fish_detected": detection_metadata["fish_count"] > 0,
153
- "fish_count": detection_metadata["fish_count"]
154
- })
155
-
156
- confidences.extend(detection_metadata["confidences"])
157
-
158
- # Display fish status
159
- fish_detected = detection_metadata['fish_count'] > 0
160
- fish_status = f"<b><span style='color: green; font-size: 24px;'>YES</span></b> 🐟" if fish_detected else f"<b><span style='color: red; font-size: 24px;'>NO</span></b>"
161
-
162
- st.markdown(f"**Summary for {os.path.basename(image_path)}:** Fish detected: {fish_status}", unsafe_allow_html=True)
163
-
164
- # Display images side by side
165
- col1, col2 = st.columns(2)
166
- with col1:
167
- st.image(image, caption=f"Uploaded Image - {os.path.basename(image_path)}", use_column_width=True)
168
- with col2:
169
- st.image(result_image, caption=f"Detection Results - {os.path.basename(image_path)}", use_column_width=True)
170
-
171
- st.success(f"Detection completed for {os.path.basename(image_path)} successfully! 🐟")
172
-
173
- else:
174
- st.warning(f"No marine ecosystems detected in {os.path.basename(image_path)}.")
175
-
176
- st.session_state["all_detections"] = all_detections
177
- return summary_data, confidences
178
-
179
-
180
- # Function to display a summary table and scatter plot side by side with image labels
181
- def display_summary(summary_data, confidences):
182
- if summary_data:
183
- df = pd.DataFrame(summary_data)
184
-
185
- col1, col2 = st.columns(2)
186
-
187
- with col1:
188
- st.subheader("Summary of Detections")
189
- st.table(df[["image_name", "fish_count"]])
190
-
191
- with col2:
192
- st.subheader("Fish Detection Confidence Levels")
193
- fig, ax = plt.subplots()
194
- confidence_index = 0
195
-
196
- for i, row in df.iterrows():
197
- num_confidences_for_image = len([c for c in confidences[confidence_index:confidence_index + row["fish_count"]]])
198
-
199
- for j in range(num_confidences_for_image):
200
- if confidence_index < len(confidences):
201
- ax.scatter(confidence_index, confidences[confidence_index], c='blue')
202
- ax.text(confidence_index, confidences[confidence_index], row['image_name'],
203
- fontsize=10, ha='center', va='bottom', rotation=0)
204
- confidence_index += 1
205
-
206
- ax.axhline(final_confidence, color='red', linestyle='--', label=f'Final Threshold ({final_confidence})')
207
- ax.set_xlabel('Detections')
208
- ax.set_ylabel('Confidence Level')
209
- ax.legend(loc='lower left')
210
- st.pyplot(fig)
211
-
212
- if st.session_state.get("all_detections"):
213
- json_data = json.dumps(st.session_state["all_detections"], indent=4)
214
- st.download_button(
215
- label="Download Results as JSON & Reset",
216
- data=json_data,
217
- file_name="all_detections.json",
218
- mime="application/json",
219
- key="download_json_bottom"
220
- )
221
-
222
- # Image uploader with multiple file support
223
- st.markdown('<div class="custom-file-uploader">', unsafe_allow_html=True)
224
- uploaded_files = st.file_uploader("Choose image(s)...", type=["png", "jpg", "jpeg"], accept_multiple_files=True)
225
-
226
- # Check if files are uploaded, hide the "Auto Run with Sample Images" button if they are
227
- if not uploaded_files and not st.session_state.get('use_sample_images', False):
228
- use_sample_images = st.button("Or Auto Run Using Sample Images", key="sample_button")
229
- else:
230
- use_sample_images = None
231
- st.markdown('</div>', unsafe_allow_html=True)
232
-
233
- # Add the functionality for the "Try it with Sample Images" button
234
- if use_sample_images:
235
- sample_images = load_sample_images()
236
- st.session_state['use_sample_images'] = True
237
- for sample_image in sample_images:
238
- st.session_state.setdefault('uploaded_files', []).append(sample_image)
239
- st.session_state['run_automatically'] = True
240
-
241
- # Display the Run, Clear, and Download buttons with enhanced styling
242
- if uploaded_files or st.session_state.get('uploaded_files'):
243
- col1, col2, col3 = st.columns([1, 1, 1], gap="small")
244
-
245
- if not st.session_state.get('use_sample_images', False):
246
- with col1:
247
- run_button = st.button("Click to Run", key="run_button")
248
- else:
249
- run_button = None
250
-
251
- # Initialize clear_button to None to avoid NameError
252
- clear_button = None
253
-
254
- # Conditionally hide the "Clear Results" button while processing
255
- with col2:
256
- if not st.session_state.get('processing', False):
257
- clear_button = st.button("Clear Results", key="clear_button")
258
-
259
- # Run automatically if triggered by the sample images button or the run button
260
- if run_button or st.session_state.get('run_automatically'):
261
- st.session_state['processing'] = True # Set the processing flag
262
- summary_data, confidences = process_images(uploaded_files or st.session_state['uploaded_files'])
263
- display_summary(summary_data, confidences)
264
- st.session_state['processing'] = False # Reset the processing flag after processing is done
265
- st.session_state['run_automatically'] = False
266
- st.session_state['use_sample_images'] = False
267
-
268
- # Now this check will work, even if clear_button is not defined earlier
269
- if clear_button:
270
- st.session_state.clear()
271
-
272
- if st.session_state.get("all_detections"):
273
- with col3:
274
- json_data = json.dumps(st.session_state["all_detections"], indent=4)
275
- st.download_button(
276
- label="Download Results as JSON & Reset",
277
- data=json_data,
278
- file_name="all_detections.json",
279
- mime="application/json",
280
- key="download_json"
281
- )
 
1
+ # Import required modules
2
+ import streamlit as st
3
+ from ultralytics import YOLO
4
+ from PIL import Image
5
+ import os
6
+ import json
7
+ import logging
8
+ import tempfile
9
+ import pandas as pd
10
+ import matplotlib.pyplot as plt
11
+
12
+ st.set_page_config(
13
+ page_title="Fish Detector",
14
+ page_icon="🐟",
15
+ layout="wide"
16
+ )
17
+ sample_images_folder = "./images/sample_images"
18
+ logging.basicConfig(level=logging.INFO)
19
+
20
+ # Model loading
21
+ model_folder = "./models"
22
+ st.sidebar.title("🐟 Fish or No Fish Detector")
23
+ st.sidebar.markdown("""
24
+ ### For more information:
25
+ - Contact: Michael.Akridge@NOAA.gov
26
+ - Visit the [GitHub repository](https://github.com/MichaelAkridge-NOAA/Fish-or-No-Fish-Detector/)
27
+ """)
28
+ # Display model links
29
+ st.sidebar.markdown("### Model Links")
30
+ st.sidebar.markdown("- [YOLO11 Fish Detector - Grayscale](https://huggingface.co/akridge/yolo11-fish-detector-grayscale)")
31
+ st.sidebar.markdown("- [YOLO11 Segment Fish - Grayscale](https://huggingface.co/akridge/yolo11-segment-fish-grayscale)")
32
+ model_name = st.sidebar.selectbox("Select a YOLO model", os.listdir(model_folder))
33
+ model_path = os.path.join(model_folder, model_name)
34
+ if not os.path.exists(model_path):
35
+ st.error(f"Model file not found at {model_path}. Please check your setup.")
36
+ st.stop()
37
+ model = YOLO(model_path)
38
+
39
+ # Sidebar configuration
40
+ st.sidebar.header("Model Parameters")
41
+ confidence = st.sidebar.slider("Detection Confidence Threshold", 0.0, 1.0, 0.35)
42
+ final_confidence = st.sidebar.slider("Final Yes/No Confidence Threshold", 0.0, 1.0, 0.5)
43
+
44
+ # Title and description
45
+ st.title("🐟 Fish or No Fish Detector (grayscale)")
46
+ st.write("""
47
+ Is there a fish 🐟 or not? Upload one or more grayscale images to detect fish. Using a trained [Ultralytics YOLO11 Model](https://github.com/ultralytics/ultralytics) for its object detection capabilities.
48
+
49
+ """)
50
+
51
+ # Custom CSS for button and uploader alignment
52
+ st.markdown("""
53
+ <style>
54
+ .custom-file-uploader {
55
+ display: flex;
56
+ align-items: center;
57
+ margin-top: -10px; /* Adjust to move button closer */
58
+ justify-content: flex-start;
59
+ }
60
+ .css-1cpxqw2 {
61
+ flex-grow: 1; /* Let file uploader take remaining space */
62
+ }
63
+ .sample-button {
64
+ font-size: 14px;
65
+ padding: 8px;
66
+ background-color: #007BFF;
67
+ color: white;
68
+ border: none;
69
+ border-radius: 5px;
70
+ cursor: pointer;
71
+ margin-left: 10px;
72
+ height: 38px; /* Ensure button matches uploader height */
73
+ }
74
+ .sample-button:hover {
75
+ background-color: #0056b3;
76
+ }
77
+ </style>
78
+ """, unsafe_allow_html=True)
79
+
80
+ # Custom CSS for default button styling
81
+ st.markdown("""
82
+ <style>
83
+ .stButton>button, .stDownloadButton>button {
84
+ width: 100%;
85
+ padding: 10px;
86
+ border-radius: 5px;
87
+ font-size: 18px;
88
+ font-weight: bold;
89
+ background-color: #007BFF;
90
+ color: white;
91
+ border: none;
92
+ cursor: pointer;
93
+ }
94
+ .stButton>button:hover, .stDownloadButton>button:hover {
95
+ background-color: #0056b3;
96
+ }
97
+ </style>
98
+ """, unsafe_allow_html=True)
99
+ # Load sample images function
100
+ def load_sample_images():
101
+ return [os.path.join(sample_images_folder, img) for img in os.listdir(sample_images_folder) if img.lower().endswith(('png', 'jpg', 'jpeg'))]
102
+
103
+ # Prediction function
104
+ def run(image_path):
105
+ results = model.predict(image_path, conf=confidence)
106
+ boxes = []
107
+ fish_count = 0
108
+ confidences = []
109
+
110
+ for box in results[0].boxes:
111
+ x1, y1, x2, y2 = box.xyxy[0].tolist()
112
+ conf = box.conf[0].item()
113
+ class_id = int(box.cls[0].item())
114
+ class_label = model.names[class_id].lower() # Normalize to lowercase
115
+
116
+ if class_label == "fish" and conf > confidence:
117
+ fish_count += 1
118
+ confidences.append(conf)
119
+
120
+ boxes.append({"x1": x1, "y1": y1, "x2": x2, "y2": y2, "confidence": conf, "class_id": class_id, "class_label": class_label})
121
+
122
+ return results[0].plot()[:, :, ::-1], {"fish_count": fish_count, "confidences": confidences}
123
+
124
+ # Process images function with directory creation
125
+ # Reusable function to handle multiple image uploads and display results
126
+ def process_images(uploaded_files):
127
+ all_detections = []
128
+ result_images = []
129
+ summary_data = []
130
+ confidences = []
131
+ temp_dir = tempfile.gettempdir()
132
+
133
+ for uploaded_file in uploaded_files:
134
+ if isinstance(uploaded_file, str): # Check if it's a sample image path
135
+ image_path = uploaded_file
136
+ image = Image.open(image_path)
137
+ else:
138
+ image = Image.open(uploaded_file)
139
+ image_path = os.path.join(temp_dir, f"{uploaded_file.name}")
140
+ image.save(image_path)
141
+
142
+ st.write(f"Detecting in {os.path.basename(image_path)}...")
143
+ with st.spinner('Running detection...'):
144
+ result_image, detection_metadata = run(image_path)
145
+
146
+ if result_image is not None:
147
+ result_images.append((result_image, os.path.basename(image_path)))
148
+ all_detections.append(detection_metadata)
149
+
150
+ summary_data.append({
151
+ "image_name": os.path.basename(image_path),
152
+ "fish_detected": detection_metadata["fish_count"] > 0,
153
+ "fish_count": detection_metadata["fish_count"]
154
+ })
155
+
156
+ confidences.extend(detection_metadata["confidences"])
157
+
158
+ # Display fish status
159
+ fish_detected = detection_metadata['fish_count'] > 0
160
+ fish_status = f"<b><span style='color: green; font-size: 24px;'>YES</span></b> 🐟" if fish_detected else f"<b><span style='color: red; font-size: 24px;'>NO</span></b>"
161
+
162
+ st.markdown(f"**Summary for {os.path.basename(image_path)}:** Fish detected: {fish_status}", unsafe_allow_html=True)
163
+
164
+ # Display images side by side
165
+ col1, col2 = st.columns(2)
166
+ with col1:
167
+ st.image(image, caption=f"Uploaded Image - {os.path.basename(image_path)}", use_column_width=True)
168
+ with col2:
169
+ st.image(result_image, caption=f"Detection Results - {os.path.basename(image_path)}", use_column_width=True)
170
+
171
+ st.success(f"Detection completed for {os.path.basename(image_path)} successfully! 🐟")
172
+
173
+ else:
174
+ st.warning(f"No marine ecosystems detected in {os.path.basename(image_path)}.")
175
+
176
+ st.session_state["all_detections"] = all_detections
177
+ return summary_data, confidences
178
+
179
+
180
+ # Function to display a summary table and scatter plot side by side with image labels
181
+ def display_summary(summary_data, confidences):
182
+ if summary_data:
183
+ df = pd.DataFrame(summary_data)
184
+
185
+ col1, col2 = st.columns(2)
186
+
187
+ with col1:
188
+ st.subheader("Summary of Detections")
189
+ st.table(df[["image_name", "fish_count"]])
190
+
191
+ with col2:
192
+ st.subheader("Fish Detection Confidence Levels")
193
+ fig, ax = plt.subplots()
194
+ confidence_index = 0
195
+
196
+ for i, row in df.iterrows():
197
+ num_confidences_for_image = len([c for c in confidences[confidence_index:confidence_index + row["fish_count"]]])
198
+
199
+ for j in range(num_confidences_for_image):
200
+ if confidence_index < len(confidences):
201
+ ax.scatter(confidence_index, confidences[confidence_index], c='blue')
202
+ ax.text(confidence_index, confidences[confidence_index], row['image_name'],
203
+ fontsize=10, ha='center', va='bottom', rotation=0)
204
+ confidence_index += 1
205
+
206
+ ax.axhline(final_confidence, color='red', linestyle='--', label=f'Final Threshold ({final_confidence})')
207
+ ax.set_xlabel('Detections')
208
+ ax.set_ylabel('Confidence Level')
209
+ ax.legend(loc='lower left')
210
+ st.pyplot(fig)
211
+
212
+ if st.session_state.get("all_detections"):
213
+ json_data = json.dumps(st.session_state["all_detections"], indent=4)
214
+ st.download_button(
215
+ label="Download Results as JSON & Reset",
216
+ data=json_data,
217
+ file_name="all_detections.json",
218
+ mime="application/json",
219
+ key="download_json_bottom"
220
+ )
221
+
222
+ # Image uploader with multiple file support
223
+ st.markdown('<div class="custom-file-uploader">', unsafe_allow_html=True)
224
+ uploaded_files = st.file_uploader("Choose image(s)...", type=["png", "jpg", "jpeg"], accept_multiple_files=True)
225
+
226
+ # Check if files are uploaded, hide the "Auto Run with Sample Images" button if they are
227
+ if not uploaded_files and not st.session_state.get('use_sample_images', False):
228
+ use_sample_images = st.button("Or Auto Run Using Sample Images", key="sample_button")
229
+ else:
230
+ use_sample_images = None
231
+ st.markdown('</div>', unsafe_allow_html=True)
232
+
233
+ # Add the functionality for the "Try it with Sample Images" button
234
+ if use_sample_images:
235
+ sample_images = load_sample_images()
236
+ st.session_state['use_sample_images'] = True
237
+ for sample_image in sample_images:
238
+ st.session_state.setdefault('uploaded_files', []).append(sample_image)
239
+ st.session_state['run_automatically'] = True
240
+
241
+ # Display the Run, Clear, and Download buttons with enhanced styling
242
+ if uploaded_files or st.session_state.get('uploaded_files'):
243
+ col1, col2, col3 = st.columns([1, 1, 1], gap="small")
244
+
245
+ if not st.session_state.get('use_sample_images', False):
246
+ with col1:
247
+ run_button = st.button("Click to Run", key="run_button")
248
+ else:
249
+ run_button = None
250
+
251
+ # Initialize clear_button to None to avoid NameError
252
+ clear_button = None
253
+
254
+ # Conditionally hide the "Clear Results" button while processing
255
+ with col2:
256
+ if not st.session_state.get('processing', False):
257
+ clear_button = st.button("Clear Results", key="clear_button")
258
+
259
+ # Run automatically if triggered by the sample images button or the run button
260
+ if run_button or st.session_state.get('run_automatically'):
261
+ st.session_state['processing'] = True # Set the processing flag
262
+ summary_data, confidences = process_images(uploaded_files or st.session_state['uploaded_files'])
263
+ display_summary(summary_data, confidences)
264
+ st.session_state['processing'] = False # Reset the processing flag after processing is done
265
+ st.session_state['run_automatically'] = False
266
+ st.session_state['use_sample_images'] = False
267
+
268
+ # Now this check will work, even if clear_button is not defined earlier
269
+ if clear_button:
270
+ st.session_state.clear()
271
+
272
+ if st.session_state.get("all_detections"):
273
+ with col3:
274
+ json_data = json.dumps(st.session_state["all_detections"], indent=4)
275
+ st.download_button(
276
+ label="Download Results as JSON & Reset",
277
+ data=json_data,
278
+ file_name="all_detections.json",
279
+ mime="application/json",
280
+ key="download_json"
281
+ )