|
|
|
import streamlit as st |
|
from ultralytics import YOLO |
|
from PIL import Image |
|
import os |
|
import json |
|
import logging |
|
import tempfile |
|
import pandas as pd |
|
import matplotlib.pyplot as plt |
|
|
|
# Page-level Streamlit settings: browser tab title, tab icon and wide layout.
st.set_page_config(page_title="Fish Detector", page_icon="π", layout="wide")

# Folder scanned by load_sample_images() for the bundled demo pictures.
sample_images_folder = "./images/sample_images"

# Root logger reports INFO and above.
logging.basicConfig(level=logging.INFO)
|
|
|
|
|
# Directory whose entries are offered as selectable YOLO models.
model_folder = "./models"

st.sidebar.title("π Fish or No Fish Detector")
st.sidebar.markdown("""
### For more information:
- Contact: Michael.Akridge@NOAA.gov
- Visit the [GitHub repository](https://github.com/MichaelAkridge-NOAA/Fish-or-No-Fish-Detector/)
""")

st.sidebar.markdown("### Model Links")
st.sidebar.markdown("- [YOLO11 Fish Detector - Grayscale](https://huggingface.co/akridge/yolo11-fish-detector-grayscale)")
st.sidebar.markdown("- [YOLO11 Segment Fish - Grayscale](https://huggingface.co/akridge/yolo11-segment-fish-grayscale)")

# A missing models directory makes os.listdir() raise, and an empty one leaves
# the selectbox without a usable name to join into a path, so both cases are
# reported on the page instead of surfacing as a traceback. Sorting keeps the
# option order (and therefore the default choice) stable between runs.
try:
    available_models = sorted(os.listdir(model_folder))
except OSError:
    available_models = []

if not available_models:
    st.error(f"No models found in {model_folder}. Please check your setup.")
    st.stop()

model_name = st.sidebar.selectbox("Select a YOLO model", available_models)
model_path = os.path.join(model_folder, model_name)

# The entry could disappear between listing and loading.
if not os.path.exists(model_path):
    st.error(f"Model file not found at {model_path}. Please check your setup.")
    st.stop()

model = YOLO(model_path)
|
|
|
|
|
# Sidebar thresholds: `confidence` is passed to the model and used to filter
# detections in run(); `final_confidence` is drawn as the reference line in
# the summary chart.
st.sidebar.header("Model Parameters")
confidence = st.sidebar.slider(
    "Detection Confidence Threshold",
    min_value=0.0,
    max_value=1.0,
    value=0.35,
)
final_confidence = st.sidebar.slider(
    "Final Yes/No Confidence Threshold",
    min_value=0.0,
    max_value=1.0,
    value=0.5,
)

# Main page heading and introduction.
st.title("π Fish or No Fish Detector (grayscale)")
st.write("""
Is there a fish π or not? Upload one or more grayscale images to detect fish. Using a trained [Ultralytics YOLO11 Model](https://github.com/ultralytics/ultralytics) for its object detection capabilities.

""")
|
|
|
|
|
# Style rules for the uploader wrapper and the sample button classes.
# NOTE(review): `.css-1cpxqw2` looks like a generated class name -- confirm it
# still matches the deployed Streamlit version.
_UPLOADER_CSS = """
<style>
.custom-file-uploader {
display: flex;
align-items: center;
margin-top: -10px; /* Adjust to move button closer */
justify-content: flex-start;
}
.css-1cpxqw2 {
flex-grow: 1; /* Let file uploader take remaining space */
}
.sample-button {
font-size: 14px;
padding: 8px;
background-color: #007BFF;
color: white;
border: none;
border-radius: 5px;
cursor: pointer;
margin-left: 10px;
height: 38px; /* Ensure button matches uploader height */
}
.sample-button:hover {
background-color: #0056b3;
}
</style>
"""
st.markdown(_UPLOADER_CSS, unsafe_allow_html=True)
|
|
|
|
|
# Shared look for regular and download buttons (full width, blue, bold),
# injected as raw HTML so the <style> tag is passed through.
_BUTTON_CSS = """
<style>
.stButton>button, .stDownloadButton>button {
width: 100%;
padding: 10px;
border-radius: 5px;
font-size: 18px;
font-weight: bold;
background-color: #007BFF;
color: white;
border: none;
cursor: pointer;
}
.stButton>button:hover, .stDownloadButton>button:hover {
background-color: #0056b3;
}
</style>
"""
st.markdown(_BUTTON_CSS, unsafe_allow_html=True)
|
|
|
def load_sample_images(folder=None):
    """Return the paths of the sample images found in a folder.

    Args:
        folder: Directory to scan. When omitted, the module-level
            ``sample_images_folder`` is used.

    Returns:
        A name-sorted list of ``os.path.join(folder, name)`` paths for every
        regular file whose name ends in ``.png``, ``.jpg`` or ``.jpeg``
        (case-insensitive).

    Raises:
        OSError: Propagated from ``os.listdir`` when the folder cannot be
            listed (for example, when it does not exist).
    """
    if folder is None:
        folder = sample_images_folder

    # The leading dot matters: without it, names such as "fakepng" would be
    # accepted because they merely end in the letters "png".
    image_suffixes = (".png", ".jpg", ".jpeg")

    candidates = (
        os.path.join(folder, name)
        for name in sorted(os.listdir(folder))
        if name.lower().endswith(image_suffixes)
    )
    # Skip directories that happen to carry an image-like name.
    return [path for path in candidates if os.path.isfile(path)]
|
|
|
|
|
def run(image_path):
    """Run the selected model on one image and tally its fish detections.

    Args:
        image_path: Image source passed unchanged to ``model.predict``.

    Returns:
        A ``(annotated_image, metadata)`` tuple. ``annotated_image`` is the
        array produced by ``results[0].plot()`` with its last axis reversed
        (presumably a BGR -> RGB channel swap for display -- confirm against
        the Ultralytics documentation). ``metadata`` is a dict with
        ``"fish_count"`` (int) and ``"confidences"`` (list of float), covering
        only boxes labelled "fish" whose score is strictly above the
        module-level ``confidence`` threshold.
    """
    result = model.predict(image_path, conf=confidence)[0]

    # Only the fish scores are kept; the previous per-box coordinate list was
    # built and then thrown away, so it is no longer collected.
    fish_scores = []
    for box in result.boxes:
        score = box.conf[0].item()
        label = model.names[int(box.cls[0].item())].lower()
        if label == "fish" and score > confidence:
            fish_scores.append(score)

    metadata = {"fish_count": len(fish_scores), "confidences": fish_scores}
    return result.plot()[:, :, ::-1], metadata
|
|
|
|
|
|
|
def process_images(uploaded_files):
    """Run detection on every input image, render the results, collect stats.

    Args:
        uploaded_files: Iterable whose items are either filesystem paths
            (``str``) or uploaded file objects that ``PIL.Image.open`` can
            read and that expose a ``name`` attribute.

    Returns:
        A ``(summary_data, confidences)`` tuple. ``summary_data`` holds one
        dict per processed image with the keys ``image_name``,
        ``fish_detected`` and ``fish_count``; ``confidences`` is the flat list
        of fish scores in processing order.

    Side effects:
        Writes progress, images and status messages to the Streamlit page and
        stores the per-image metadata in
        ``st.session_state["all_detections"]``.
    """
    all_detections = []
    summary_data = []
    confidences = []

    # Uploads are materialised inside a private scratch directory that is
    # deleted when the batch finishes. Writing them straight into the shared
    # system temp directory left the files behind and let equally named
    # uploads from different sessions overwrite each other.
    with tempfile.TemporaryDirectory() as temp_dir:
        for uploaded_file in uploaded_files:
            if isinstance(uploaded_file, str):
                image_path = uploaded_file
                image = Image.open(image_path)
            else:
                image = Image.open(uploaded_file)
                # basename() stops a client-supplied name containing path
                # separators from pointing outside the scratch directory.
                image_path = os.path.join(temp_dir, os.path.basename(uploaded_file.name))
                image.save(image_path)

            image_name = os.path.basename(image_path)

            st.write(f"Detecting in {image_name}...")
            with st.spinner('Running detection...'):
                result_image, detection_metadata = run(image_path)

            if result_image is None:
                st.warning(f"No marine ecosystems detected in {image_name}.")
                continue

            fish_count = detection_metadata["fish_count"]
            fish_detected = fish_count > 0

            all_detections.append(detection_metadata)
            summary_data.append({
                "image_name": image_name,
                "fish_detected": fish_detected,
                "fish_count": fish_count,
            })
            confidences.extend(detection_metadata["confidences"])

            if fish_detected:
                fish_status = "<b><span style='color: green; font-size: 24px;'>YES</span></b> π"
            else:
                fish_status = "<b><span style='color: red; font-size: 24px;'>NO</span></b>"
            st.markdown(
                f"**Summary for {image_name}:** Fish detected: {fish_status}",
                unsafe_allow_html=True,
            )

            # Original on the left, annotated detection output on the right.
            col1, col2 = st.columns(2)
            with col1:
                st.image(image, caption=f"Uploaded Image - {image_name}", use_column_width=True)
            with col2:
                st.image(result_image, caption=f"Detection Results - {image_name}", use_column_width=True)

            st.success(f"Detection completed for {image_name} successfully! π")

    st.session_state["all_detections"] = all_detections
    return summary_data, confidences
|
|
|
|
|
|
|
def display_summary(summary_data, confidences):
    """Render the per-image summary table, the confidence chart and a download.

    Args:
        summary_data: List of dicts with at least ``image_name`` and
            ``fish_count``, one per processed image.
        confidences: Flat list of fish scores ordered so that the first
            ``fish_count`` values belong to the first image, the next batch to
            the second image, and so on.

    Returns:
        None. Nothing is rendered when ``summary_data`` is empty.
    """
    if not summary_data:
        return

    df = pd.DataFrame(summary_data)
    col1, col2 = st.columns(2)

    with col1:
        st.subheader("Summary of Detections")
        st.table(df[["image_name", "fish_count"]])

    with col2:
        st.subheader("Fish Detection Confidence Levels")
        fig, ax = plt.subplots()

        # Walk the flat score list image by image; `start` is the x position
        # of the current image's first detection.
        start = 0
        for _, row in df.iterrows():
            scores = confidences[start:start + row["fish_count"]]
            for x_pos, score in enumerate(scores, start=start):
                ax.scatter(x_pos, score, c='blue')
                ax.text(x_pos, score, row['image_name'],
                        fontsize=10, ha='center', va='bottom', rotation=0)
            start += len(scores)

        ax.axhline(final_confidence, color='red', linestyle='--', label=f'Final Threshold ({final_confidence})')
        ax.set_xlabel('Detections')
        ax.set_ylabel('Confidence Level')
        ax.legend(loc='lower left')
        st.pyplot(fig)
        # pyplot keeps every figure it creates registered until it is closed;
        # without this, each script rerun would leave another figure in memory.
        plt.close(fig)

    if st.session_state.get("all_detections"):
        json_data = json.dumps(st.session_state["all_detections"], indent=4)
        st.download_button(
            label="Download Results as JSON & Reset",
            data=json_data,
            file_name="all_detections.json",
            mime="application/json",
            key="download_json_bottom"
        )
|
|
|
|
|
# Opening tag for the wrapper targeted by the .custom-file-uploader CSS rule.
st.markdown('<div class="custom-file-uploader">', unsafe_allow_html=True)

uploaded_files = st.file_uploader(
    "Choose image(s)...",
    type=["png", "jpg", "jpeg"],
    accept_multiple_files=True,
)

# The sample-run button is offered only while nothing has been uploaded and no
# sample run is already flagged in the session; otherwise the flag stays None.
use_sample_images = None
if not (uploaded_files or st.session_state.get('use_sample_images', False)):
    use_sample_images = st.button("Or Auto Run Using Sample Images", key="sample_button")

st.markdown('</div>', unsafe_allow_html=True)
|
|
|
|
|
# Queue the bundled sample images for an automatic run when the sample button
# was pressed on this script run.
if use_sample_images:
    sample_images = load_sample_images()
    if sample_images:
        st.session_state['use_sample_images'] = True
        # Replace rather than append: appending to the stored list queued the
        # same samples again on every click, so they were processed twice.
        st.session_state['uploaded_files'] = list(sample_images)
        st.session_state['run_automatically'] = True
    else:
        # Setting the run flags with nothing queued would hide the sample
        # button while no run could ever start to reset them.
        st.warning(f"No sample images found in {sample_images_folder}.")
|
|
|
|
|
# Action row (run / clear / download), shown once there is something to
# process: either fresh uploads or paths queued in the session.
if uploaded_files or st.session_state.get('uploaded_files'):
    col1, col2, col3 = st.columns([1, 1, 1], gap="small")

    # A manual run button is only needed when the run was not triggered by
    # the sample-images shortcut.
    run_button = None
    if not st.session_state.get('use_sample_images', False):
        with col1:
            run_button = st.button("Click to Run", key="run_button")

    clear_button = None
    with col2:
        if not st.session_state.get('processing', False):
            clear_button = st.button("Clear Results", key="clear_button")

    if run_button or st.session_state.get('run_automatically'):
        st.session_state['processing'] = True
        try:
            summary_data, confidences = process_images(uploaded_files or st.session_state['uploaded_files'])
            display_summary(summary_data, confidences)
        finally:
            # Reset the flags even when processing fails or is interrupted;
            # otherwise "Clear Results" stays hidden and the automatic run
            # fires again on every rerun of the script.
            st.session_state['processing'] = False
            st.session_state['run_automatically'] = False
            st.session_state['use_sample_images'] = False

    if clear_button:
        st.session_state.clear()

    if st.session_state.get("all_detections"):
        with col3:
            json_data = json.dumps(st.session_state["all_detections"], indent=4)
            st.download_button(
                label="Download Results as JSON & Reset",
                data=json_data,
                file_name="all_detections.json",
                mime="application/json",
                key="download_json"
            )
|
|