Update app.py
Browse files
app.py
CHANGED
@@ -40,39 +40,78 @@ def resize_image_pil(image, new_width, new_height):
|
|
40 |
|
41 |
return resized
|
42 |
|
43 |
-
def inference(input_img, transparency):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
44 |
transform = transforms.ToTensor()
|
45 |
input_img = transform(input_img)
|
46 |
-
|
47 |
input_img = input_img.unsqueeze(0)
|
48 |
outputs = model(input_img)
|
|
|
|
|
|
|
|
|
49 |
_, prediction = torch.max(outputs, 1)
|
50 |
-
target_layers = [model.layer2[
|
51 |
-
cam = GradCAM(model=model, target_layers=target_layers
|
52 |
-
grayscale_cam = cam(input_tensor=input_img, targets=
|
53 |
grayscale_cam = grayscale_cam[0, :]
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
|
|
|
|
|
|
|
|
|
|
60 |
|
61 |
demo = gr.Interface(
|
62 |
-
inference,
|
63 |
-
inputs
|
64 |
gr.Image(width=256, height=256, label="Input Image"),
|
65 |
-
gr.Slider(0,
|
66 |
-
gr.Slider(-2, -1, value=-2,
|
67 |
],
|
68 |
outputs = [
|
69 |
"text",
|
70 |
gr.Image(width=256, height=256, label="Output"),
|
71 |
gr.Label(num_top_classes=3)
|
72 |
],
|
73 |
-
|
74 |
-
|
75 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
76 |
)
|
77 |
|
78 |
demo.launch()
|
|
|
40 |
|
41 |
return resized
|
42 |
|
43 |
+
# def inference(input_img, transparency):
|
44 |
+
# transform = transforms.ToTensor()
|
45 |
+
# input_img = transform(input_img)
|
46 |
+
# input_img = input_img.to(device)
|
47 |
+
# input_img = input_img.unsqueeze(0)
|
48 |
+
# outputs = model(input_img)
|
49 |
+
# _, prediction = torch.max(outputs, 1)
|
50 |
+
# target_layers = [model.layer2[-2]]
|
51 |
+
# cam = GradCAM(model=model, target_layers=target_layers, use_cuda=True)
|
52 |
+
# grayscale_cam = cam(input_tensor=input_img, targets=targets)
|
53 |
+
# grayscale_cam = grayscale_cam[0, :]
|
54 |
+
# img = input_img.squeeze(0).to('cpu')
|
55 |
+
# img = inv_normalize(img)
|
56 |
+
# rgb_img = np.transpose(img, (1, 2, 0))
|
57 |
+
# rgb_img = rgb_img.numpy()
|
58 |
+
# visualization = show_cam_on_image(rgb_img, grayscale_cam, use_rgb=True, image_weight=transparency)
|
59 |
+
# return classes[prediction[0].item()], visualization
|
60 |
+
|
61 |
+
def inference(input_img, transparency=0.5, target_layer_number=-1):
    """Classify an image with the CIFAR10 model and overlay a GradCAM heatmap.

    Args:
        input_img: Input image accepted by ``resize_image_pil`` (PIL image
            from the Gradio ``gr.Image`` input — TODO confirm exact type).
        transparency: Weight of the original image in the overlay, in [0, 1].
        target_layer_number: Negative index into ``model.layer2`` selecting
            the layer GradCAM hooks into.

    Returns:
        Tuple of (predicted class name, GradCAM visualization image,
        dict mapping each class name to its softmax confidence).
    """
    # The model was trained on 32x32 CIFAR10 inputs.
    input_img = resize_image_pil(input_img, 32, 32)
    input_img = np.array(input_img)
    org_img = input_img  # keep the raw RGB array for the overlay

    # Fails fast if the image is not 32x32 RGB (e.g. grayscale or RGBA).
    input_img = input_img.reshape((32, 32, 3))

    transform = transforms.ToTensor()
    input_img = transform(input_img)
    input_img = input_img.unsqueeze(0)  # add batch dimension

    outputs = model(input_img)

    # Softmax over the single sample's logits. Generalized from the original
    # hard-coded range(10) so the dict always tracks the classes list.
    o = torch.softmax(outputs.flatten(), dim=0)
    confidences = {classes[i]: float(o[i]) for i in range(len(classes))}
    _, prediction = torch.max(outputs, 1)

    # GradCAM needs gradients, so no torch.no_grad() around the forward pass.
    target_layers = [model.layer2[target_layer_number]]
    cam = GradCAM(model=model, target_layers=target_layers)
    grayscale_cam = cam(input_tensor=input_img, targets=None)
    grayscale_cam = grayscale_cam[0, :]

    visualization = show_cam_on_image(
        org_img / 255,  # show_cam_on_image expects floats in [0, 1]; assumes uint8 input — TODO confirm
        grayscale_cam,
        use_rgb=True,
        image_weight=transparency,
    )

    return classes[prediction[0].item()], visualization, confidences
|
90 |
+
|
91 |
+
|
92 |
+
|
93 |
|
94 |
# Gradio UI wiring: one image and two sliders in; (label text, GradCAM
# overlay image, top-3 confidences) out — matching inference()'s 3-tuple.
demo = gr.Interface(
    fn=inference,
    inputs=[
        gr.Image(width=256, height=256, label="Input Image"),
        # Maps to `transparency`: weight of the original image in the overlay.
        gr.Slider(0, 1, value=0.5, label="Overall opacity value"),
        # Maps to `target_layer_number`: negative index into model.layer2.
        gr.Slider(-2, -1, value=-2, label="Which model layer to use for GradCAM?")
    ],
    outputs = [
        "text",
        gr.Image(width=256, height=256, label="Output"),
        gr.Label(num_top_classes=3)
    ],
    title="CIFAR10 trained on ResNet18 with GradCAM",
    description = "A simple Gradio interface to infer on ResNet model with GradCAM results shown on top.",
    # Example rows are (image path, transparency, target_layer_number);
    # the files are presumably shipped alongside app.py — verify they exist.
    examples = [
        ["cat.jpg", 0.5, -1],
        ["dog.jpg", 0.7, -2]
    ]
)

# Start the Gradio server (blocking call).
demo.launch()
|