File size: 5,736 Bytes
048c0b7 559f812 048c0b7 559f812 048c0b7 21e5166 559f812 21e5166 559f812 21e5166 048c0b7 559f812 048c0b7 559f812 048c0b7 12167ec 048c0b7 cd423db 048c0b7 f908d63 048c0b7 559f812 048c0b7 559f812 048c0b7 559f812 048c0b7 559f812 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 |
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms, models
from PIL import Image, UnidentifiedImageError
import streamlit as st
import numpy as np
import requests
from io import BytesIO
from kan_linear import KANLinear
import logging
import os
# Attach a default stderr handler to the root logger and let INFO and above
# through, so the logging.info/error calls in this script are visible.
# (basicConfig does nothing if the root logger already has handlers.)
logging.basicConfig(level="INFO")
class CNNKAN(nn.Module):
    """Four-stage convolutional feature extractor followed by a two-layer KAN head.

    Every stage is: 3x3 conv (padding 1) -> BatchNorm -> SELU -> 2x2 max-pool.
    The head's input width ``256 * 12 * 12`` corresponds to a 200x200 input
    image, since four floor-halvings give 200 -> 100 -> 50 -> 25 -> 12.
    ``kan2`` maps to a single output per sample, which callers treat as a
    logit (they apply ``torch.sigmoid`` to it).
    """

    # Channel progression through the conv stages: in -> stage1 -> ... -> stage4.
    _STAGE_CHANNELS = (3, 32, 64, 128, 256)

    def __init__(self):
        super().__init__()
        # The attribute names (conv1/bn1/pool1 ... conv4/bn4/pool4) are the
        # state_dict keys that saved weights are loaded into, so they must
        # not change; registering them in this order also mirrors the
        # original layer-by-layer construction.
        widths = self._STAGE_CHANNELS
        for stage, (c_in, c_out) in enumerate(zip(widths, widths[1:]), start=1):
            setattr(self, f"conv{stage}", nn.Conv2d(c_in, c_out, kernel_size=3, padding=1))
            setattr(self, f"bn{stage}", nn.BatchNorm2d(c_out))
            setattr(self, f"pool{stage}", nn.MaxPool2d(2))
        self.dropout = nn.Dropout(0.5)
        self.kan1 = KANLinear(256 * 12 * 12, 512)
        self.kan2 = KANLinear(512, 1)

    def forward(self, x):
        # Convolutional trunk: conv -> BN -> SELU -> pool for each stage.
        for stage in range(1, len(self._STAGE_CHANNELS)):
            x = getattr(self, f"conv{stage}")(x)
            x = F.selu(getattr(self, f"bn{stage}")(x))
            x = getattr(self, f"pool{stage}")(x)
        # Flatten everything but the batch axis for the KAN head.
        features = x.view(x.size(0), -1)
        # Dropout is applied in front of each KAN layer (a no-op in eval mode).
        hidden = self.kan1(self.dropout(features))
        return self.kan2(self.dropout(hidden))
def load_model(weights_path, device):
    """Build a CNNKAN on ``device``, load weights from ``weights_path``, and return it in eval mode.

    Keys beginning with ``'module.'`` have that prefix removed before loading.
    That is the prefix PyTorch adds when a model is wrapped in
    ``nn.DataParallel``, so presumably that is how the checkpoint was saved.
    ``map_location=device`` lets weights saved on a GPU load on a CPU-only host.

    NOTE(review): ``torch.load`` unpickles the file; only point this at
    trusted weight files.
    """
    model = CNNKAN().to(device)
    checkpoint = torch.load(weights_path, map_location=device)
    wrapper_prefix = 'module.'
    unwrapped = {
        (name[len(wrapper_prefix):] if name.startswith(wrapper_prefix) else name): tensor
        for name, tensor in checkpoint.items()
    }
    model.load_state_dict(unwrapped)
    model.eval()
    return model
class CustomImageLoadingError(Exception):
    """Raised when an image cannot be loaded from a URL.

    The message is meant to be user-readable: the Streamlit app shows
    ``str(e)`` in the sidebar.
    """


def load_image_from_url(url, timeout=10):
    """Download ``url`` and return its contents as an RGB ``PIL.Image``.

    Args:
        url: Address of the image. The path component must end in .jpg,
            .jpeg, .png or .webp (case-insensitive). Query strings and
            fragments are ignored for this check.
        timeout: Seconds passed to ``requests.get``. Without it, a stalled
            server would block the Streamlit script indefinitely.

    Returns:
        The decoded image converted to RGB mode.

    Raises:
        CustomImageLoadingError: if the extension is not allowed, the
            response's Content-Type is not ``image/...``, any network or HTTP
            error occurs, or PIL cannot decode the bytes. Wrapped exceptions
            are chained as ``__cause__``.
    """
    from urllib.parse import urlparse  # local import keeps this block self-contained

    valid_extensions = {'jpg', 'jpeg', 'png', 'webp'}
    try:
        logging.info("Loading image from URL: %s", url)
        # Inspect only the path: "cat.jpg?w=400" must give "jpg", not "jpg?w=400".
        file_extension = os.path.splitext(urlparse(url).path)[1][1:].lower()
        if file_extension not in valid_extensions:
            raise CustomImageLoadingError(
                f"URL does not point to an image with a valid extension: {file_extension}"
            )
        response = requests.get(url, timeout=timeout)
        response.raise_for_status()  # turn 4xx/5xx responses into HTTPError
        # A missing header becomes '' and fails the image check below,
        # instead of raising KeyError.
        content_type = response.headers.get('Content-Type', '')
        logging.info("Content-Type: %s", content_type)
        if not content_type.strip().lower().startswith('image/'):
            raise CustomImageLoadingError(f"URL does not point to an image: {content_type}")
        img = Image.open(BytesIO(response.content)).convert('RGB')
        logging.info("Image successfully loaded and converted to RGB")
        return img
    except CustomImageLoadingError:
        # Our own validation errors already carry a user-facing message.
        # Without this clause, the catch-all below would re-wrap them as
        # "unexpected" errors.
        raise
    except requests.HTTPError as e:
        logging.error("HTTPError while loading image: %s", e)
        raise CustomImageLoadingError(f"Error loading image from URL: {e}") from e
    except UnidentifiedImageError as e:
        logging.error("UnidentifiedImageError while loading image: %s", e)
        raise CustomImageLoadingError(f"Cannot identify image file: {e}") from e
    except requests.RequestException as e:
        # Connection errors, timeouts, invalid URLs (HTTPError is handled above).
        logging.error("RequestException while loading image: %s", e)
        raise CustomImageLoadingError(f"Error loading image from URL: {e}") from e
    except Exception as e:
        logging.exception("Unexpected error while loading image: %s", e)
        raise CustomImageLoadingError(f"Error loading image from URL: {e}") from e
def preprocess_image(image) -> torch.Tensor:
    """Turn a PIL image into a one-image batch tensor for ``CNNKAN``.

    The image is resized to exactly 200x200, which is the input size that
    CNNKAN's ``256 * 12 * 12`` head width corresponds to. It is then converted
    with ``transforms.ToTensor`` and given a leading batch axis.

    NOTE(review): no mean/std normalization is applied; confirm this matches
    the pipeline the weights were trained with.
    """
    to_model_input = transforms.Compose(
        [
            transforms.Resize((200, 200)),
            transforms.ToTensor(),
        ]
    )
    single = to_model_input(image)
    return single.unsqueeze(dim=0)
# Streamlit app
st.title("Cat and Dog Classification with CNN-KAN")
st.sidebar.title("Upload Images")
uploaded_file = st.sidebar.file_uploader("Choose an image...", type=["jpg", "jpeg", "png", "webp"])
# Strip so whitespace pasted around a URL neither triggers a fetch nor breaks it.
image_url = st.sidebar.text_input("Or enter image URL...").strip()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


@st.cache_resource
def _load_cached_model(weights_path, device_name):
    """Return the CNNKAN model, reading the weights from disk only once per server process.

    Streamlit re-executes this script on every widget interaction, so
    without caching the weights file would be reloaded each time.
    ``device_name`` is a plain string to keep the cache key trivially hashable.
    """
    return load_model(weights_path, torch.device(device_name))


model = _load_cached_model('weights/best_model_weights_CNNKAN.pth', str(device))

img = None
if uploaded_file is not None:
    logging.info("Image uploaded via file uploader")
    try:
        img = Image.open(uploaded_file).convert('RGB')
    except OSError as e:
        # UnidentifiedImageError subclasses OSError; truncated files raise a
        # plain OSError during convert(). Report instead of crashing the page.
        logging.error("Could not read uploaded image: %s", e)
        st.sidebar.error(f"Cannot identify image file: {e}")
elif image_url:
    try:
        img = load_image_from_url(image_url)
    except CustomImageLoadingError as e:
        st.sidebar.error(str(e))
    except Exception as e:
        # Top-level UI boundary: show the error to the user, but keep the
        # traceback in the logs.
        logging.exception("Unexpected error while loading image from URL")
        st.sidebar.error(f"Unexpected error: {e}")
# --- Sidebar footer: divider, author links and copyright notice ---
name = "Wayan Dadang"
footer_markdown = f"""
[LinkedIn](https://www.linkedin.com/in/wayan-dadang-801757116/)
[GitHub](https://github.com/Wayan123)
[Resume](https://wayan123.github.io/)
© {name} - {2024}
"""
st.sidebar.write("-----")
st.sidebar.write("Follow me on:")
st.sidebar.markdown(footer_markdown, unsafe_allow_html=True)
# Show the selected image and, when the button is pressed, classify it.
if img is not None:
    st.image(np.array(img), caption='Uploaded Image.', use_column_width=True)
    if st.button('Predict'):
        batch = preprocess_image(img).to(device)
        # Inference only, so no autograd graph is needed.
        with torch.no_grad():
            logit = model(batch)
        # The sigmoid of the single logit is read as the "Dog" score:
        # values below 0.5 are reported as Cat.
        dog_score = torch.sigmoid(logit).item()
        st.write(f"Prediction: {dog_score:.4f}")
        verdict = (
            "This image is classified as a Cat."
            if dog_score < 0.5
            else "This image is classified as a Dog."
        )
        st.write(verdict)
|