darshanjani committed
Commit • 58c979f
1 Parent(s): 09036b9
Upload 17 files
- Store/examples/airplane.png +0 -0
- Store/examples/bird.webp +0 -0
- Store/examples/car.jpg +0 -0
- Store/examples/cat.jpeg +0 -0
- Store/examples/deer.webp +0 -0
- Store/examples/dog1.jpg +0 -0
- Store/examples/frog1.webp +0 -0
- Store/examples/horse.jpg +0 -0
- Store/examples/shipp.jpg +0 -0
- Store/examples/truck1.jpg +0 -0
- Store/model.pth +3 -0
- Utilities/config.py +65 -0
- Utilities/model.py +305 -0
- Utilities/transforms.py +11 -0
- Utilities/utils.py +77 -0
- Utilities/visualize.py +31 -0
- app.py +108 -0
Store/examples/airplane.png
ADDED
Store/examples/bird.webp
ADDED
Store/examples/car.jpg
ADDED
Store/examples/cat.jpeg
ADDED
Store/examples/deer.webp
ADDED
Store/examples/dog1.jpg
ADDED
Store/examples/frog1.webp
ADDED
Store/examples/horse.jpg
ADDED
Store/examples/shipp.jpg
ADDED
Store/examples/truck1.jpg
ADDED
Store/model.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:fbd64f23fadf7bffb54d9f55e39771ebb15e40e3d64660d3972cc650def37d51
size 26333951
Utilities/config.py
ADDED
@@ -0,0 +1,65 @@
import torch.nn.functional as F


# SEED

SEED = 1

# DATASET

CLASSES = (
    "Airplane",
    "Automobile",
    "Bird",
    "Cat",
    "Deer",
    "Dog",
    "Frog",
    "Horse",
    "Ship",
    "Truck",
)

SHUFFLE = True
DATA_DIR = "../data"
NUM_WORKERS = 4
PIN_MEMORY = True

# TRAINING HYPERPARAMETERS

CRITERION = F.cross_entropy
INPUT_SIZE = (3, 32, 32)
NUM_CLASSES = 10
LEARNING_RATE = 0.001
WEIGHT_DECAY = 1e-4
BATCH_SIZE = 512
NUM_EPOCHS = 24
DROPOUT_PERCENTAGE = 0.05
LAYER_NORM = "bn"

# OPTIMIZER & SCHEDULER

LRFINDER_END_LR = 0.1
LRFINDER_NUM_ITERATIONS = 50
LRFINDER_STEP_MODE = "exp"

OCLR_DIV_FACTOR = 100
OCLR_FINAL_DIV_FACTOR = 100
OCLR_THREE_PHASE = False
OCLR_ANNEAL_STRATEGY = "linear"

# COMPUTE RELATED

ACCELERATOR = "cpu"
PRECISION = 32

# STORAGE

TRAINING_STAT_STORE = "Store/training_stats.csv"
MODEL_SAVE_PATH = "Store/model.pth"
PRED_STORE_PATH = "Store/pred_store.pth"
EXAMPLE_IMG_PATH = "Store/examples/"

# VISUALIZATION

NORM_CONF_MAT = True
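For reference, a minimal sketch of how these constants are consumed by the rest of the repo (an illustration, not part of the commit; it assumes the package is importable as `Utilities`, as app.py below does):

from Utilities import config

# The tuple of class names lines up with the model's output dimension.
assert len(config.CLASSES) == config.NUM_CLASSES  # 10 CIFAR-10 classes
print(config.MODEL_SAVE_PATH)   # "Store/model.pth"
print(config.EXAMPLE_IMG_PATH)  # "Store/examples/"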
Utilities/model.py
ADDED
@@ -0,0 +1,305 @@
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pytorch_lightning as pl
import seaborn as sns
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchmetrics
from torch.optim.lr_scheduler import OneCycleLR
from torch_lr_finder import LRFinder

from . import config  # Custom config file
from .visualize import plot_incorrect_preds


class Net(pl.LightningModule):
    def __init__(
        self,
        num_classes=10,
        dropout_percentage=0,
        norm='bn',
        num_groups=2,
        criterion=F.cross_entropy,
        learning_rate=0.001,
        weight_decay=0.0,
    ):
        super().__init__()

        # Define the normalization layer factory
        if norm == 'bn':
            self.norm = nn.BatchNorm2d
        elif norm == 'gn':
            self.norm = lambda in_dim: nn.GroupNorm(
                num_groups=num_groups, num_channels=in_dim
            )
        elif norm == 'ln':
            # GroupNorm with one group per channel set behaves like LayerNorm
            self.norm = lambda in_dim: nn.GroupNorm(
                num_groups=in_dim, num_channels=in_dim
            )
        else:
            raise ValueError(f"Unsupported norm {norm!r}; expected 'bn', 'gn' or 'ln'")

        # Define loss
        self.criterion = criterion

        # Define metrics
        self.accuracy = torchmetrics.Accuracy(
            task='multiclass', num_classes=num_classes
        )
        self.confusion_matrix = torchmetrics.ConfusionMatrix(
            task='multiclass', num_classes=num_classes
        )

        # Define the optimizer hyperparameters
        self.learning_rate = learning_rate
        self.weight_decay = weight_decay

        # Prediction storage
        self.pred_store = {
            "test_preds": torch.tensor([]),
            "test_labels": torch.tensor([]),
            "test_incorrect": [],  # (image, target, pred, confidence) tuples
        }
        self.log_store = {  # per-epoch history (currently unused)
            "train_loss_epoch": [],
            "train_acc_epoch": [],
            "val_loss_epoch": [],
            "val_acc_epoch": [],
            "test_loss_epoch": [],
            "test_acc_epoch": [],
        }

        # Define the network architecture
        self.prep_layer = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, padding=1),  # 32x32x3 | 1 -> 32x32x64 | 3
            self.norm(64),
            nn.ReLU(),
            nn.Dropout(dropout_percentage),
        )

        self.l1 = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=3, padding=1),  # 32x32x128 | 5
            nn.MaxPool2d(2, 2),  # 16x16x128 | 6
            self.norm(128),
            nn.ReLU(),
            nn.Dropout(dropout_percentage),
        )
        self.l1res = nn.Sequential(
            nn.Conv2d(128, 128, kernel_size=3, padding=1),  # 16x16x128 | 10
            self.norm(128),
            nn.ReLU(),
            nn.Dropout(dropout_percentage),
            nn.Conv2d(128, 128, kernel_size=3, padding=1),  # 16x16x128 | 14
            self.norm(128),
            nn.ReLU(),
            nn.Dropout(dropout_percentage),
        )
        self.l2 = nn.Sequential(
            nn.Conv2d(128, 256, kernel_size=3, padding=1),  # 16x16x256 | 18
            nn.MaxPool2d(2, 2),  # 8x8x256 | 19
            self.norm(256),
            nn.ReLU(),
            nn.Dropout(dropout_percentage),
        )
        self.l3 = nn.Sequential(
            nn.Conv2d(256, 512, kernel_size=3, padding=1),  # 8x8x512 | 27
            nn.MaxPool2d(2, 2),  # 4x4x512 | 28
            self.norm(512),
            nn.ReLU(),
            nn.Dropout(dropout_percentage),
        )
        self.l3res = nn.Sequential(
            nn.Conv2d(512, 512, kernel_size=3, padding=1),  # 4x4x512 | 36
            self.norm(512),
            nn.ReLU(),
            nn.Dropout(dropout_percentage),
            nn.Conv2d(512, 512, kernel_size=3, padding=1),  # 4x4x512 | 44
            self.norm(512),
            nn.ReLU(),
            nn.Dropout(dropout_percentage),
        )
        self.maxpool = nn.MaxPool2d(4, 4)

        # Classifier
        self.linear = nn.Linear(512, 10)

    def forward(self, x):
        x = self.prep_layer(x)
        x = self.l1(x)
        x = x + self.l1res(x)
        x = self.l2(x)
        x = self.l3(x)
        x = x + self.l3res(x)
        x = self.maxpool(x)
        x = x.view(-1, 512)
        x = self.linear(x)
        return F.log_softmax(x, dim=1)

    def training_step(self, batch, batch_idx):
        data, target = batch

        # Forward pass
        pred = self.forward(data)

        # Calculate loss
        loss = self.criterion(pred, target)

        # Calculate accuracy
        accuracy = self.accuracy(pred, target)

        # Log metrics
        self.log_dict(
            {"train_loss": loss, "train_acc": accuracy},
            on_step=True,
            on_epoch=True,
            prog_bar=True,
            logger=True,
        )
        return loss

    def validation_step(self, batch, batch_idx):
        data, target = batch

        # Forward pass
        pred = self.forward(data)

        # Calculate loss
        loss = self.criterion(pred, target)

        # Calculate accuracy
        accuracy = self.accuracy(pred, target)

        # Log metrics
        self.log_dict(
            {"val_loss": loss, "val_acc": accuracy},
            on_step=True,
            on_epoch=True,
            prog_bar=True,
            logger=True,
        )
        return loss

    def test_step(self, batch, batch_idx):
        data, target = batch

        # Forward pass
        pred = self.forward(data)
        argmax_pred = pred.argmax(dim=1)

        # Calculate loss
        loss = self.criterion(pred, target)

        # Calculate accuracy
        accuracy = self.accuracy(pred, target)

        # Update confusion matrix
        self.confusion_matrix.update(pred, target)

        # Log metrics
        self.log_dict(
            {"test_loss": loss, "test_acc": accuracy},
            on_step=True,
            on_epoch=True,
            prog_bar=True,
            logger=True,
        )

        # Store the predictions, labels and incorrect predictions

        # Move everything to CPU once before storing
        data, target, pred, argmax_pred = data.cpu(), target.cpu(), pred.cpu(), argmax_pred.cpu()

        # Store the predictions
        self.pred_store["test_preds"] = torch.cat((self.pred_store["test_preds"], argmax_pred), dim=0)
        self.pred_store["test_labels"] = torch.cat((self.pred_store["test_labels"], target), dim=0)

        for d, t, p, o in zip(data, target, argmax_pred, pred):
            if not p.eq(t.view_as(p)).item():
                self.pred_store["test_incorrect"].append(
                    (d, t, p, o[p.item()])
                )

        return loss

    def find_bestLR_LRFinder(self, optimizer):
        lr_finder = LRFinder(self, optimizer, criterion=self.criterion)
        lr_finder.range_test(
            self.trainer.datamodule.train_dataloader(),
            end_lr=config.LRFINDER_END_LR,
            num_iter=config.LRFINDER_NUM_ITERATIONS,
            step_mode=config.LRFINDER_STEP_MODE,
        )

        # Extract the loss and learning rate from history
        loss = np.array(lr_finder.history['loss'])
        lr = np.array(lr_finder.history['lr'])

        # Find the learning rate with the steepest negative gradient
        gradient = np.gradient(loss)
        idx = np.argmin(gradient)
        best_lr = lr[idx]

        try:
            lr_finder.plot()  # plotting is best-effort
        except Exception:
            pass

        print("BEST_LR: ", best_lr)
        lr_finder.reset()

        return best_lr

    def configure_optimizers(self):
        optimizer = self.get_only_optimizer()
        best_lr = self.find_bestLR_LRFinder(optimizer)
        scheduler = OneCycleLR(
            optimizer,
            max_lr=best_lr,  # use the LR-finder result instead of a hard-coded value
            steps_per_epoch=len(self.trainer.datamodule.train_dataloader()),
            epochs=config.NUM_EPOCHS,
            pct_start=5 / config.NUM_EPOCHS,
            div_factor=config.OCLR_DIV_FACTOR,
            three_phase=config.OCLR_THREE_PHASE,
            final_div_factor=config.OCLR_FINAL_DIV_FACTOR,
            anneal_strategy=config.OCLR_ANNEAL_STRATEGY,
        )

        return [optimizer], [{"scheduler": scheduler, "interval": "step", "frequency": 1}]

    def get_only_optimizer(self):
        optimizer = optim.Adam(
            self.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay
        )
        return optimizer

    def on_test_end(self) -> None:
        super().on_test_end()

        # Confusion matrix
        confmat = self.confusion_matrix.cpu().compute().numpy()
        if config.NORM_CONF_MAT:
            df_confmat = pd.DataFrame(
                confmat / np.sum(confmat, axis=1)[:, None],  # row-normalize (per true class)
                index=[i for i in config.CLASSES],
                columns=[i for i in config.CLASSES],
            )
        else:
            df_confmat = pd.DataFrame(
                confmat,
                index=[i for i in config.CLASSES],
                columns=[i for i in config.CLASSES],
            )
        plt.figure(figsize=(7, 5))
        sns.heatmap(df_confmat, annot=True, cmap="Blues", fmt=".3f", linewidths=0.5)
        plt.tight_layout()
        plt.ylabel("True label")
        plt.xlabel("Predicted label")
        plt.show()

    def plot_incorrect_predictions_helper(self, num_imgs=10):
        return plot_incorrect_preds(
            self.pred_store["test_incorrect"], config.CLASSES, num_imgs
        )
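A quick smoke test of the architecture (a sketch, not part of the commit; it assumes the package imports resolve and needs no trained checkpoint):

import torch
from Utilities.model import Net

net = Net(num_classes=10, dropout_percentage=0.05, norm="bn")
net.eval()
with torch.no_grad():
    out = net(torch.randn(2, 3, 32, 32))  # a CIFAR-10-sized dummy batch
print(out.shape)         # torch.Size([2, 10])
print(out.exp().sum(1))  # rows are log-probabilities, so these sum to ~1.0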
Utilities/transforms.py
ADDED
@@ -0,0 +1,11 @@
import albumentations as A
from albumentations.pytorch import ToTensorV2

# Define the transforms (test-time only)

test_transforms = A.Compose([
    A.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261)),  # CIFAR-10 mean/std
    ToTensorV2(),
])
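Albumentations transforms take a named `image=` argument and return a dict; a small usage sketch (the random array below stands in for a real HxWxC uint8 image):

import numpy as np
from Utilities.transforms import test_transforms

img = np.random.randint(0, 256, size=(32, 32, 3), dtype=np.uint8)
tensor = test_transforms(image=img)["image"]  # normalized CHW float tensor
print(tensor.shape)  # torch.Size([3, 32, 32])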
Utilities/utils.py
ADDED
@@ -0,0 +1,77 @@
import torch
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image

from . import config
from .transforms import test_transforms


def generate_confidences(
    model,
    input_img,
    num_top_preds,
):
    input_img = test_transforms(image=input_img)
    input_img = input_img["image"]

    input_img = input_img.unsqueeze(0)  # add a batch dimension
    model.eval()
    log_probs = model(input_img)[0].detach()
    model.train()
    probs = torch.exp(log_probs)  # the model outputs log-probabilities

    confidences = {
        config.CLASSES[i]: float(probs[i]) for i in range(len(config.CLASSES))
    }
    # Keep the top `num_top_preds` confidences by value
    confidences = {
        k: v
        for k, v in sorted(confidences.items(), key=lambda item: item[1], reverse=True)[
            :num_top_preds
        ]
    }
    return input_img, confidences


def generate_gradcam(
    model,
    org_img,
    input_img,
    show_gradcam,
    gradcam_layer,
    gradcam_opacity,
):
    if show_gradcam:
        if gradcam_layer == -1:
            target_layers = [model.l3[-1]]
        elif gradcam_layer == -2:
            target_layers = [model.l2[-1]]

        cam = GradCAM(
            model=model,
            target_layers=target_layers,
        )
        grayscale_cam = cam(input_tensor=input_img, targets=None)
        grayscale_cam = grayscale_cam[0, :]

        visualization = show_cam_on_image(
            org_img / 255,
            grayscale_cam,
            use_rgb=True,
            image_weight=(1 - gradcam_opacity),
        )
    else:
        visualization = None
    return visualization


def generate_missclassified_imgs(
    model,
    show_misclassified,
    num_misclassified,
):
    if show_misclassified:
        plot = model.plot_incorrect_predictions_helper(num_misclassified)
    else:
        plot = None
    return plot
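How the two image helpers chain together outside Gradio - a hedged sketch assuming `model` is a loaded `Net` and `img` is a 32x32x3 uint8 array:

import numpy as np
from Utilities.utils import generate_confidences, generate_gradcam

img = np.random.randint(0, 256, size=(32, 32, 3), dtype=np.uint8)
# generate_confidences returns the preprocessed batched tensor, so the same
# tensor can be fed straight into generate_gradcam for the gradients.
processed, confidences = generate_confidences(model=model, input_img=img, num_top_preds=3)
heatmap = generate_gradcam(
    model=model,
    org_img=img,          # raw image, used as the backdrop
    input_img=processed,  # normalized tensor, used for the gradients
    show_gradcam=True,
    gradcam_layer=-1,     # -1 -> model.l3[-1], -2 -> model.l2[-1]
    gradcam_opacity=0.5,
)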
Utilities/visualize.py
ADDED
@@ -0,0 +1,31 @@
import matplotlib.pyplot as plt
from torchvision import transforms
import random as rand

def plot_incorrect_preds(incorrect, classes, num_imgs):
    # num_imgs must be a multiple of 5
    assert num_imgs % 5 == 0
    assert len(incorrect) >= num_imgs

    incorrect_inds = rand.sample(range(len(incorrect)), num_imgs)

    # incorrect: list of (data, target, pred, output) tuples
    fig = plt.figure(figsize=(10, num_imgs // 2))
    plt.suptitle("Target | Predicted Label")
    for i in range(num_imgs):
        cur_incorrect = incorrect[incorrect_inds[i]]
        plt.subplot(num_imgs // 5, 5, i + 1, aspect="auto")

        # Undo the CIFAR-10 normalization: Normalize((-mean / std), (1.0 / std))
        unnormalized = transforms.Normalize(
            (-1.98947368, -1.98436214, -1.71072797), (4.048583, 4.11522634, 3.83141762)
        )(cur_incorrect[0])
        plt.imshow(transforms.ToPILImage()(unnormalized))
        plt.title(
            f"{classes[cur_incorrect[1].item()]}|{classes[cur_incorrect[2].item()]}",
            # fontsize=8,
        )
        plt.xticks([])
        plt.yticks([])
    plt.tight_layout()
    return fig
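A synthetic call, to show the expected shape of each `incorrect` entry (random tensors standing in for the stored test predictions):

import torch
from Utilities import config
from Utilities.visualize import plot_incorrect_preds

# Each entry is (image CHW tensor, target label, predicted label, confidence).
incorrect = [
    (torch.randn(3, 32, 32), torch.tensor(3), torch.tensor(5), torch.tensor(0.9))
    for _ in range(5)
]
fig = plot_incorrect_preds(incorrect, config.CLASSES, num_imgs=5)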
app.py
ADDED
@@ -0,0 +1,108 @@
import gradio as gr
import torch

from Utilities.model import Net
from Utilities import config
from Utilities.utils import generate_confidences, generate_gradcam, generate_missclassified_imgs

inputs = [
    gr.Image(shape=(32, 32), label="Input Image"),
    gr.Slider(minimum=1, maximum=10, step=1, label="Number of Top Predictions to Display"),
    gr.Checkbox(default=False, label="Show GradCAM"),
    gr.Slider(minimum=-2, maximum=-1, step=1, value=-1, label="GradCAM Layer (from the end)"),
    gr.Slider(minimum=0, maximum=1, value=0.5, label="GradCAM Heatmap Opacity"),
    gr.Checkbox(label="Show Incorrect Predictions"),
    gr.Slider(minimum=5, maximum=50, step=5, label="Number of Incorrect Predictions to Display"),
]

model = Net(
    num_classes=config.NUM_CLASSES,
    dropout_percentage=config.DROPOUT_PERCENTAGE,
    norm=config.LAYER_NORM,
    criterion=config.CRITERION,
    learning_rate=config.LEARNING_RATE,
    weight_decay=config.WEIGHT_DECAY,
)

model.load_state_dict(
    torch.load(
        config.MODEL_SAVE_PATH,  # the name defined in Utilities/config.py
        map_location=torch.device(config.ACCELERATOR),
    )
)

model.pred_store = torch.load(config.PRED_STORE_PATH, map_location=torch.device(config.ACCELERATOR))


def generate_gradio_output(
    input_img,
    num_top_preds,
    show_gradcam,
    gradcam_layer,
    gradcam_opacity,
    show_misclassified,
    num_misclassified,
):
    processed_img, confidences = generate_confidences(
        model=model,
        input_img=input_img,
        num_top_preds=num_top_preds,
    )

    visualization = generate_gradcam(
        model=model,
        org_img=input_img,
        input_img=processed_img,
        show_gradcam=show_gradcam,
        gradcam_layer=gradcam_layer,
        gradcam_opacity=gradcam_opacity,
    )

    plot = generate_missclassified_imgs(
        model=model,
        show_misclassified=show_misclassified,
        num_misclassified=num_misclassified,
    )

    return confidences, visualization, plot


outputs = [
    gr.Label(visible=True, scale=0.5, label="Classification Confidences"),
    gr.Image(shape=(32, 32), label="GradCAM Visualization").style(
        width=256, height=256, visible=True
    ),
    gr.Plot(visible=True, label="Misclassified Images"),
]

examples = [
    [config.EXAMPLE_IMG_PATH + "cat.jpeg", 3, True, -2, 0.68, True, 40],
    [config.EXAMPLE_IMG_PATH + "horse.jpg", 3, True, -2, 0.59, True, 25],
    [config.EXAMPLE_IMG_PATH + "bird.webp", 10, True, -1, 0.55, True, 20],
    [config.EXAMPLE_IMG_PATH + "dog1.jpg", 10, True, -1, 0.33, True, 45],
    [config.EXAMPLE_IMG_PATH + "frog1.webp", 5, True, -1, 0.64, True, 40],
    [config.EXAMPLE_IMG_PATH + "deer.webp", 1, True, -2, 0.45, True, 20],
    [config.EXAMPLE_IMG_PATH + "airplane.png", 3, True, -2, 0.43, True, 40],
    [config.EXAMPLE_IMG_PATH + "shipp.jpg", 7, True, -1, 0.6, True, 30],
    [config.EXAMPLE_IMG_PATH + "car.jpg", 2, True, -1, 0.68, True, 30],
    [config.EXAMPLE_IMG_PATH + "truck1.jpg", 5, True, -2, 0.51, True, 35],
]

title = "Image Classification (CIFAR10 - 10 Classes) with GradCAM"
description = """A simple Gradio interface to visualize the output of a CNN trained on the CIFAR10 dataset, with GradCAM and misclassified images.
The architecture is inspired by David Page's (myrtle.ai) DAWNBench-winning model architecture.
Input an image and select the number of top predictions to display - you will see the top predictions and their corresponding confidence scores.
You can also choose whether to show GradCAM for the image (it uses the gradients of the classification score with respect to the final convolutional feature map to identify the parts of the input image that most affect the classification score).
You need to select the model layer the gradients are taken from - this affects how much of the image is used to compute the GradCAM.
You can also choose whether to show misclassified images - images from the test set that the model got wrong.
Some examples are provided in the examples tab.
"""

gr.Interface(
    fn=generate_gradio_output,
    inputs=inputs,
    outputs=outputs,
    title=title,
    description=description,
    examples=examples,
).launch()
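The handler can also be exercised headlessly before the `launch()` call - a sketch assuming the Store artifacts (model.pth, pred_store.pth) are present so the module-level loads succeed:

import numpy as np

img = np.random.randint(0, 256, size=(32, 32, 3), dtype=np.uint8)
confidences, cam, plot = generate_gradio_output(img, 3, True, -1, 0.5, False, 5)
print(confidences)  # e.g. {"Frog": 0.41, "Cat": 0.22, "Bird": 0.12} (illustrative values)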