Manu101 commited on
Commit
853bc3f
1 Parent(s): 47f1ab8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +57 -18
app.py CHANGED
@@ -40,39 +40,78 @@ def resize_image_pil(image, new_width, new_height):
40
 
41
  return resized
42
 
43
- def inference(input_img, transparency):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
  transform = transforms.ToTensor()
45
  input_img = transform(input_img)
46
- input_img = input_img.to(device)
47
  input_img = input_img.unsqueeze(0)
48
  outputs = model(input_img)
 
 
 
 
49
  _, prediction = torch.max(outputs, 1)
50
- target_layers = [model.layer2[-2]]
51
- cam = GradCAM(model=model, target_layers=target_layers, use_cuda=True)
52
- grayscale_cam = cam(input_tensor=input_img, targets=targets)
53
  grayscale_cam = grayscale_cam[0, :]
54
- img = input_img.squeeze(0).to('cpu')
55
- img = inv_normalize(img)
56
- rgb_img = np.transpose(img, (1, 2, 0))
57
- rgb_img = rgb_img.numpy()
58
- visualization = show_cam_on_image(rgb_img, grayscale_cam, use_rgb=True, image_weight=transparency)
59
- return classes[prediction[0].item()], visualization
 
 
 
 
 
60
 
61
  demo = gr.Interface(
62
- inference,
63
- inputs = [
64
  gr.Image(width=256, height=256, label="Input Image"),
65
- gr.Slider(0, 1, value=0.5, label="Overall opacity fo the overlay"),
66
- gr.Slider(-2, -1, value=-2, step=1, label="Which GradCAM layer?")
67
  ],
68
  outputs = [
69
  "text",
70
  gr.Image(width=256, height=256, label="Output"),
71
  gr.Label(num_top_classes=3)
72
  ],
73
- title="CIFAR10 trained on ResNet18 with GradCAM feature",
74
- description = "A simple Gradio app for checking GradCAM outputs from results of ResNet18 model.",
75
- examples = [["cat.jpg", 0.5, -1], ["dog.jpg", 0.7, -2]]
 
 
 
 
 
 
76
  )
77
 
78
  demo.launch()
 
40
 
41
  return resized
42
 
43
+ # def inference(input_img, transparency):
44
+ # transform = transforms.ToTensor()
45
+ # input_img = transform(input_img)
46
+ # input_img = input_img.to(device)
47
+ # input_img = input_img.unsqueeze(0)
48
+ # outputs = model(input_img)
49
+ # _, prediction = torch.max(outputs, 1)
50
+ # target_layers = [model.layer2[-2]]
51
+ # cam = GradCAM(model=model, target_layers=target_layers, use_cuda=True)
52
+ # grayscale_cam = cam(input_tensor=input_img, targets=targets)
53
+ # grayscale_cam = grayscale_cam[0, :]
54
+ # img = input_img.squeeze(0).to('cpu')
55
+ # img = inv_normalize(img)
56
+ # rgb_img = np.transpose(img, (1, 2, 0))
57
+ # rgb_img = rgb_img.numpy()
58
+ # visualization = show_cam_on_image(rgb_img, grayscale_cam, use_rgb=True, image_weight=transparency)
59
+ # return classes[prediction[0].item()], visualization
60
+
61
+ def inference(input_img, transparency=0.5, target_layer_number=-1):
62
+ input_img = resize_image_pil(input_img, 32, 32)
63
+ input_img = np.array(input_img)
64
+ org_img= input_img
65
+
66
+ input_img = input_img.reshape((32, 32, 3))
67
+
68
  transform = transforms.ToTensor()
69
  input_img = transform(input_img)
70
+
71
  input_img = input_img.unsqueeze(0)
72
  outputs = model(input_img)
73
+
74
+ softmax = torch.nn.Softmax(dim=0)
75
+ o = softmax(outputs.flatten())
76
+ confidences = {classes[i] : float(o[i]) for i in range(10)}
77
  _, prediction = torch.max(outputs, 1)
78
+ target_layers = [model.layer2[target_layer_number]]
79
+ cam = GradCAM(model=model, target_layers = target_layers)
80
+ grayscale_cam = cam(input_tensor=input_img, targets=None)
81
  grayscale_cam = grayscale_cam[0, :]
82
+ visualization = show_cam_on_image(
83
+ org_img/255,
84
+ grayscale_cam,
85
+ use_rgb=True,
86
+ image_weight=transparency
87
+ )
88
+
89
+ return classes[prediction[0].item()], visualization, confidences
90
+
91
+
92
+
93
 
94
  demo = gr.Interface(
95
+ fn=inference,
96
+ inputs=[
97
  gr.Image(width=256, height=256, label="Input Image"),
98
+ gr.Slider(0,1, value=0.5, label="Overall opacity value"),
99
+ gr.Slider(-2, -1, value=-2, label="Which model layer to use for GradCAM?")
100
  ],
101
  outputs = [
102
  "text",
103
  gr.Image(width=256, height=256, label="Output"),
104
  gr.Label(num_top_classes=3)
105
  ],
106
+
107
+ title="CIFAR10 trained on ResNet18 with GradCAM",
108
+
109
+ description = "A simple Gradio interface to infer on ResNet model with GradCAM results shown on top.",
110
+
111
+ examples = [
112
+ ["cat.jpg", 0.5, -1],
113
+ ["dog.jpg", 0.7, -2]
114
+ ]
115
  )
116
 
117
  demo.launch()