File size: 5,736 Bytes
048c0b7
 
 
559f812
 
048c0b7
 
 
 
 
559f812
 
 
 
 
048c0b7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21e5166
 
 
559f812
21e5166
 
559f812
21e5166
 
 
 
048c0b7
 
 
559f812
 
 
 
048c0b7
559f812
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
048c0b7
 
 
12167ec
048c0b7
 
 
 
 
cd423db
048c0b7
 
 
 
 
 
f908d63
048c0b7
 
 
 
559f812
048c0b7
 
 
 
559f812
 
048c0b7
559f812
048c0b7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
559f812
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
import logging
import os
from io import BytesIO
from urllib.parse import urlparse

import numpy as np
import requests
import streamlit as st
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image, UnidentifiedImageError
from torchvision import transforms, models

from kan_linear import KANLinear

# Configure the root logger at INFO so the logging.info progress messages
# emitted while loading images are actually shown.
logging.basicConfig(level="INFO")

class CNNKAN(nn.Module):
    """Four-stage convolutional feature extractor followed by a two-layer KAN head.

    Every stage is ``Conv2d(3x3, padding=1) -> BatchNorm2d -> SELU -> MaxPool2d(2)``,
    widening channels 3 -> 32 -> 64 -> 128 -> 256. The flattened features go
    through dropout (p=0.5), ``KANLinear(256 * 12 * 12, 512)``, dropout again,
    and ``KANLinear(512, 1)``. The result of the last layer is returned with no
    final activation.
    """

    def __init__(self):
        super().__init__()
        # Attribute names and creation order are kept fixed: the names are the
        # state_dict keys a checkpoint is loaded against.

        # Stage 1
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.pool1 = nn.MaxPool2d(2)
        # Stage 2
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        self.pool2 = nn.MaxPool2d(2)
        # Stage 3
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm2d(128)
        self.pool3 = nn.MaxPool2d(2)
        # Stage 4
        self.conv4 = nn.Conv2d(128, 256, kernel_size=3, padding=1)
        self.bn4 = nn.BatchNorm2d(256)
        self.pool4 = nn.MaxPool2d(2)

        # Head. A single dropout module is reused before both KAN layers.
        self.dropout = nn.Dropout(0.5)
        # 12 * 12 matches a 200x200 input: 200 -> 100 -> 50 -> 25 -> 12 after
        # four 2x2 max-pools; other input sizes will not fit this layer.
        self.kan1 = KANLinear(256 * 12 * 12, 512)
        self.kan2 = KANLinear(512, 1)

    def forward(self, x):
        """Run the conv stages, flatten per sample, then apply the KAN head."""
        stages = (
            (self.conv1, self.bn1, self.pool1),
            (self.conv2, self.bn2, self.pool2),
            (self.conv3, self.bn3, self.pool3),
            (self.conv4, self.bn4, self.pool4),
        )
        features = x
        for conv, norm, pool in stages:
            features = pool(F.selu(norm(conv(features))))

        flat = features.view(features.shape[0], -1)
        hidden = self.kan1(self.dropout(flat))
        return self.kan2(self.dropout(hidden))

def load_model(weights_path, device):
    """Build a CNNKAN on *device*, load saved weights into it, and return it in eval mode.

    Any key starting with ``'module.'`` has that prefix removed before loading.
    This prefix typically appears when a model wrapped in ``nn.DataParallel`` is
    saved; that is presumably the origin here, but this is unconfirmed.

    Note: ``torch.load`` unpickles the file, so only load checkpoints from a
    trusted source.
    """
    prefix = 'module.'
    model = CNNKAN().to(device)
    checkpoint = torch.load(weights_path, map_location=device)

    cleaned = {
        (key[len(prefix):] if key.startswith(prefix) else key): tensor
        for key, tensor in checkpoint.items()
    }

    model.load_state_dict(cleaned)
    model.eval()  # inference mode: dropout off, batch-norm uses running stats
    return model

class CustomImageLoadingError(Exception):
    """Raised by load_image_from_url when an image cannot be fetched or decoded."""

def load_image_from_url(url):
    try:
        logging.info(f"Loading image from URL: {url}")

        # Check the file extension
        valid_extensions = ['jpg', 'jpeg', 'png', 'webp']
        file_extension = os.path.splitext(url)[1][1:].lower()
        if file_extension not in valid_extensions:
            raise CustomImageLoadingError(f"URL does not point to an image with a valid extension: {file_extension}")

        response = requests.get(url)
        response.raise_for_status()  # Check if the request was successful

        content_type = response.headers['Content-Type']
        logging.info(f"Content-Type: {content_type}")

        # Check if the content type is an image
        if 'image' not in content_type:
            raise CustomImageLoadingError(f"URL does not point to an image: {content_type}")

        img = Image.open(BytesIO(response.content)).convert('RGB')
        logging.info("Image successfully loaded and converted to RGB")
        return img
    except requests.HTTPError as e:
        logging.error(f"HTTPError while loading image: {e}")
        raise CustomImageLoadingError(f"Error loading image from URL: {e}")
    except UnidentifiedImageError as e:
        logging.error(f"UnidentifiedImageError while loading image: {e}")
        raise CustomImageLoadingError(f"Cannot identify image file: {e}")
    except requests.RequestException as e:
        logging.error(f"RequestException while loading image: {e}")
        raise CustomImageLoadingError(f"Error loading image from URL: {e}")
    except Exception as e:
        logging.error(f"Unexpected error while loading image: {e}")
        raise CustomImageLoadingError(f"Error loading image from URL: {e}")

def preprocess_image(image):
    """Resize *image* to 200x200 and return it as a tensor with a leading batch axis of 1.

    ``ToTensor`` produces a channels-first float tensor (scaled to [0, 1] for
    8-bit images). No normalization is applied. 200x200 is the size the
    CNNKAN head's flattened dimension is built for.
    """
    resize_then_tensorize = transforms.Compose(
        [transforms.Resize((200, 200)), transforms.ToTensor()]
    )
    chw = resize_then_tensorize(image)
    # Add the batch dimension the model's forward expects.
    return chw.unsqueeze(dim=0)

# Streamlit app
st.title("Cat and Dog Classification with CNN-KAN")

st.sidebar.title("Upload Images")
uploaded_file = st.sidebar.file_uploader("Choose an image...", type=["jpg", "jpeg", "png", "webp"])
image_url = st.sidebar.text_input("Or enter image URL...")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


@st.cache_resource
def _get_cached_model(weights_path, device_type):
    """Load the CNN-KAN weights once per server process instead of on every rerun.

    Streamlit re-executes this whole script on every widget interaction, so an
    uncached ``load_model`` call would re-read the checkpoint each time. The
    cached model is shared across sessions; that is safe here because it is in
    eval mode and only used under ``torch.no_grad()``.
    """
    return load_model(weights_path, torch.device(device_type))


model = _get_cached_model('weights/best_model_weights_CNNKAN.pth', device.type)

img = None

if uploaded_file is not None:
    logging.info("Image uploaded via file uploader")
    try:
        img = Image.open(uploaded_file).convert('RGB')
    except OSError as e:  # includes PIL's UnidentifiedImageError and truncated files
        logging.error(f"Could not decode uploaded file: {e}")
        st.sidebar.error(f"Cannot identify image file: {e}")
elif image_url:
    try:
        img = load_image_from_url(image_url)
    except CustomImageLoadingError as e:
        st.sidebar.error(str(e))
    except Exception as e:
        st.sidebar.error(f"Unexpected error: {e}")

st.sidebar.write("-----")

# Define your information for the footer
name = "Wayan Dadang"

st.sidebar.write("Follow me on:")
# Render the links as a list. Consecutive single-newline lines in Markdown
# merge into one paragraph, which ran all the links and the copyright line
# together. No raw HTML is used, so unsafe_allow_html is not needed.
st.sidebar.markdown(
    "- [LinkedIn](https://www.linkedin.com/in/wayan-dadang-801757116/)\n"
    "- [GitHub](https://github.com/Wayan123)\n"
    "- [Resume](https://wayan123.github.io/)\n"
    "\n"
    f"© {name} - 2024"
)

if img is not None:
    st.image(np.array(img), caption='Uploaded Image.', use_column_width=True)
    predict_clicked = st.button('Predict')
    if predict_clicked:
        batch = preprocess_image(img).to(device)

        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            logit = model(batch)
            probability = torch.sigmoid(logit).item()

        st.write(f"Prediction: {probability:.4f}")

        # Sigmoid scores below 0.5 map to "Cat", everything else to "Dog".
        st.write(
            "This image is classified as a Cat."
            if probability < 0.5
            else "This image is classified as a Dog."
        )