Spaces:
Runtime error
Runtime error
File size: 7,009 Bytes
fc24292 0d173bb fc24292 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 |
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms, models
from PIL import Image, UnidentifiedImageError
import streamlit as st
import numpy as np
import requests
from io import BytesIO
from kan_linear import KANLinear
import logging
import os
# Configure the root logger at import time so the INFO-level progress
# messages emitted while loading images are actually shown.
logging.basicConfig(level="INFO")
# Define the model
class KANVGG16(nn.Module):
    """VGG16-style convolutional backbone followed by a KANLinear classifier head.

    With the default ``num_classes=1`` the head emits one output per image
    (binary cat/dog classification). Five 2x2 max-pools reduce a 224x224
    input to a 7x7 map, which is why the head's first layer takes
    ``512 * 7 * 7`` features.
    """

    # One entry per stage: (in_channels, out_channels, number of 3x3 convs).
    _STAGES = (
        (3, 64, 2),
        (64, 128, 2),
        (128, 256, 3),
        (256, 512, 3),
        (512, 512, 3),
    )

    def __init__(self, num_classes=1):  # For binary classification (cats and dogs)
        super().__init__()
        layers = []
        for in_ch, out_ch, n_convs in self._STAGES:
            channels = in_ch
            for _ in range(n_convs):
                layers.append(nn.Conv2d(channels, out_ch, kernel_size=3, padding=1))
                layers.append(nn.ReLU(inplace=True))
                channels = out_ch
            # Pool, then BatchNorm: this exact order fixes the Sequential
            # indices, which must match the keys of the saved checkpoint.
            layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
            layers.append(nn.BatchNorm2d(out_ch))
        self.features = nn.Sequential(*layers)

        self.classifier = nn.Sequential(
            KANLinear(512 * 7 * 7, 2048),  # Adjusted for input size 224x224
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            KANLinear(2048, 2048),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            KANLinear(2048, num_classes),
        )

    def forward(self, x):
        """Run the backbone, flatten all but the batch dimension, and classify."""
        feats = self.features(x)
        return self.classifier(feats.flatten(start_dim=1))
def load_model(weights_path, device):
    """Build a KANVGG16, load its weights, and return it in eval mode.

    Keys starting with ``'module.'`` (the prefix added by wrappers such as
    ``nn.DataParallel``) have that prefix removed so they match the bare model.

    Args:
        weights_path: Path to a saved state_dict file.
        device: torch device used both for the model and as ``map_location``.

    Returns:
        The KANVGG16 instance with the weights loaded, in eval mode.
    """
    prefix = 'module.'
    net = KANVGG16().to(device)
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    checkpoint = torch.load(weights_path, map_location=device)
    cleaned = {
        (key[len(prefix):] if key.startswith(prefix) else key): tensor
        for key, tensor in checkpoint.items()
    }
    net.load_state_dict(cleaned)
    net.eval()
    return net
class CustomImageLoadingError(Exception):
    """Raised when an image cannot be fetched from a URL or decoded.

    The Streamlit UI catches it and shows the message as a sidebar error.
    """
def load_image_from_url(url, timeout=10):
    """Download an image from ``url`` and return it as an RGB PIL image.

    The URL path must end in an accepted image extension, and the server's
    Content-Type header must mention ``image``.

    Args:
        url: HTTP(S) URL of the image. Query strings and fragments are ignored
            when checking the file extension.
        timeout: Seconds to wait on the server (connect/read) before giving up.

    Returns:
        PIL.Image.Image: The decoded image converted to RGB.

    Raises:
        CustomImageLoadingError: If the extension or Content-Type is not an
            image, the request fails, or the payload cannot be decoded.
    """
    from urllib.parse import urlparse

    try:
        logging.info(f"Loading image from URL: {url}")
        # Check the extension of the URL *path* only, so that
        # "https://host/cat.jpg?w=200" is recognised as a .jpg.
        valid_extensions = {'jpg', 'jpeg', 'png', 'webp'}
        file_extension = os.path.splitext(urlparse(url).path)[1][1:].lower()
        if file_extension not in valid_extensions:
            raise CustomImageLoadingError(f"URL does not point to an image with a valid extension: {file_extension}")
        # Without a timeout an unresponsive host would hang the app indefinitely.
        response = requests.get(url, timeout=timeout)
        response.raise_for_status()  # Check if the request was successful
        # The header may be missing; treat that as "not an image" instead of a KeyError.
        content_type = response.headers.get('Content-Type', '')
        logging.info(f"Content-Type: {content_type}")
        if 'image' not in content_type.lower():
            raise CustomImageLoadingError(f"URL does not point to an image: {content_type}")
        img = Image.open(BytesIO(response.content)).convert('RGB')
        logging.info("Image successfully loaded and converted to RGB")
        return img
    except CustomImageLoadingError:
        # Our own validation errors already carry a precise message; do not
        # let the generic handler below re-wrap them.
        raise
    except requests.HTTPError as e:
        logging.error(f"HTTPError while loading image: {e}")
        raise CustomImageLoadingError(f"Error loading image from URL: {e}") from e
    except UnidentifiedImageError as e:
        logging.error(f"UnidentifiedImageError while loading image: {e}")
        raise CustomImageLoadingError(f"Cannot identify image file: {e}") from e
    except requests.RequestException as e:
        logging.error(f"RequestException while loading image: {e}")
        raise CustomImageLoadingError(f"Error loading image from URL: {e}") from e
    except Exception as e:
        logging.error(f"Unexpected error while loading image: {e}")
        raise CustomImageLoadingError(f"Error loading image from URL: {e}") from e
def preprocess_image(image):
    """Resize ``image`` to 224x224 and return it as a batch of one tensor.

    ToTensor produces a CHW float tensor; a leading batch axis of size 1 is
    added. No mean/std normalisation is applied.
    NOTE(review): presumably this matches the training transform; confirm.
    """
    resize_then_tensorize = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ])
    chw = resize_then_tensorize(image)
    return chw[None, ...]
# Streamlit app
st.title("Cat and Dog Classification with VGG16-KAN")
st.sidebar.title("Upload Images")
uploaded_file = st.sidebar.file_uploader("Choose an image...", type=["jpg", "jpeg", "png", "webp"])
image_url = st.sidebar.text_input("Or enter image URL...")


@st.cache_resource
def _get_model(weights_path, device_type):
    """Load the classifier once per process and reuse it across reruns.

    Streamlit re-executes this whole script on every widget interaction;
    without caching, the network would be rebuilt and its weights re-read
    from disk each time. ``device_type`` is a plain string so that it is
    trivially hashable for the cache key.
    """
    return load_model(weights_path, torch.device(device_type))


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = _get_model('weights/best_model_VGG16_KAN_97.pth', device.type)

img = None
if uploaded_file is not None:
    logging.info("Image uploaded via file uploader")
    try:
        img = Image.open(uploaded_file).convert('RGB')
    except OSError as e:
        # UnidentifiedImageError and truncated-image errors are both OSErrors;
        # report them to the user instead of crashing the page.
        logging.error(f"Failed to read uploaded image: {e}")
        st.sidebar.error(f"Cannot read uploaded image: {e}")
elif image_url:
    try:
        img = load_image_from_url(image_url)
    except CustomImageLoadingError as e:
        st.sidebar.error(str(e))
    except Exception as e:
        st.sidebar.error(f"Unexpected error: {e}")

st.sidebar.write("-----")
# Define your information for the footer
name = "Wayan Dadang"
st.sidebar.write("Follow me on:")
# Create a footer section with links and copyright information
st.sidebar.markdown(f"""
[LinkedIn](https://www.linkedin.com/in/wayan-dadang-801757116/)
[GitHub](https://github.com/Wayan123)
[Resume](https://wayan123.github.io/)
© {name} - {2024}
""", unsafe_allow_html=True)

if img is not None:
    st.image(np.array(img), caption='Uploaded Image.', use_column_width=True)
    if st.button('Predict'):
        img_tensor = preprocess_image(img).to(device)
        with torch.no_grad():
            output = model(img_tensor)
        # One output unit: sigmoid maps it to a score in (0, 1); < 0.5 means Cat.
        prob = torch.sigmoid(output).item()
        st.write(f"Prediction: {prob:.4f}")
        if prob < 0.5:
            st.write("This image is classified as a Cat.")
        else:
            st.write("This image is classified as a Dog.")
|