jkang commited on
Commit
4f9e21b
1 Parent(s): e3d486c

Upload gradcam_utils.py

Browse files
Files changed (1) hide show
  1. gradcam_utils.py +141 -0
gradcam_utils.py ADDED
@@ -0,0 +1,141 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ '''
2
+ Grad-CAM visualization utilities
3
+
4
+ - Based on https://keras.io/examples/vision/grad_cam/
5
+
6
+ ---
7
+ - 2021-12-18 jkang first created
8
+ - 2022-01-16
9
+ - copied from https://huggingface.co/spaces/jkang/demo-gradcam-imagenet/blob/main/utils.py
10
+ - updated for artis/trend classifier
11
+ '''
12
+ import matplotlib.cm as cm
13
+
14
+ import os
15
+ import re
16
+ from glob import glob
17
+ import numpy as np
18
+ import tensorflow as tf
19
+ tfk = tf.keras
20
+ K = tfk.backend
21
+
22
+ # Disable GPU for testing
23
+ # os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
24
+
25
+
26
+ def get_imagenet_classes():
27
+ '''Retrieve all 1000 imagenet classes/labels as dictionaries'''
28
+ classes = tfk.applications.imagenet_utils.decode_predictions(
29
+ np.expand_dims(np.arange(1000), 0), top=1000
30
+ )
31
+ idx2lab = {cla[2]: cla[1] for cla in classes[0]}
32
+ lab2idx = {idx2lab[idx]: idx for idx in idx2lab}
33
+ return idx2lab, lab2idx
34
+
35
+
36
+ def search_by_name(str_part):
37
+ '''Search imagenet class by partial matching string'''
38
+ results = [key for key in list(lab2idx.keys()) if re.search(str_part, key)]
39
+ if len(results) != 0:
40
+ return [(key, lab2idx[key]) for key in results]
41
+ else:
42
+ return []
43
+
44
+
45
+ def get_xception_model():
46
+ '''Get model to use'''
47
+ base_model = tfk.applications.xception.Xception
48
+ preprocessor = tfk.applications.xception.preprocess_input
49
+ decode_predictions = tfk.applications.xception.decode_predictions
50
+ last_conv_layer_name = "block14_sepconv2_act"
51
+
52
+ model = base_model(weights='imagenet')
53
+ grad_model = tfk.models.Model(
54
+ inputs=[model.inputs],
55
+ outputs=[model.get_layer(last_conv_layer_name).output,
56
+ model.output]
57
+ )
58
+ return model, grad_model, preprocessor, decode_predictions
59
+
60
+
61
+ def get_img_4d_array(image_file, image_size=(299, 299)):
62
+ '''Load image as 4d array'''
63
+ img = tfk.preprocessing.image.load_img(
64
+ image_file, target_size=image_size) # PIL obj
65
+ img_array = tfk.preprocessing.image.img_to_array(
66
+ img) # float32 numpy array
67
+ img_array = np.expand_dims(img_array, axis=0) # 3d -> 4d (1,299,299,3)
68
+ return img_array
69
+
70
+
71
+ def make_gradcam_heatmap(grad_model, img_array, pred_idx=None):
72
+ '''Generate heatmap to overlay with
73
+ - img_array: 4d numpy array
74
+ - pred_idx: eg. index out of 1000 imagenet classes
75
+ if None, argmax is chosen from prediction
76
+ '''
77
+ # Get gradient of pred class w.r.t. last conv activation
78
+ with tf.GradientTape() as tape:
79
+ last_conv_act, predictions = grad_model(img_array)
80
+ if pred_idx == None:
81
+ pred_idx = tf.argmax(predictions[0])
82
+ class_channel = predictions[:, pred_idx] # (1,1000) => (1,)
83
+
84
+ # d(class_channel/last_conv_act)
85
+ grads = tape.gradient(class_channel, last_conv_act)
86
+ pooled_grads = tf.reduce_mean(grads, axis=(
87
+ 0, 1, 2)) # (1,10,10,2048) => (2048,)
88
+
89
+ # (10,10,2048) x (2048,1) => (10,10,1)
90
+ heatmap = last_conv_act[0] @ pooled_grads[..., tf.newaxis]
91
+ heatmap = tf.squeeze(heatmap) # (10,10)
92
+
93
+ # Normalize heatmap between 0 and 1
94
+ heatmap = tf.maximum(heatmap, 0) / tf.math.reduce_max(heatmap)
95
+ return heatmap, pred_idx.numpy(), predictions.numpy().squeeze()
96
+
97
+
98
+ def align_image_with_heatmap(img_array, heatmap, alpha=0.3, cmap='jet'):
99
+ '''Align the image with gradcam heatmap
100
+ - img_array: 4d numpy array
101
+ - heatmap: output of `def make_gradcam_heatmap()` as 2d numpy array
102
+ '''
103
+ img_array = img_array.squeeze() # 4d => 3d
104
+
105
+ # Rescale to 0-255 range
106
+ heatmap_scaled = np.uint8(255 * heatmap)
107
+ img_array_scaled = np.uint8(255 * img_array)
108
+
109
+ colormap = cm.get_cmap(cmap)
110
+ colors = colormap(np.arange(256))[:, :3] # mapping RGB to heatmap
111
+ heatmap_colored = colors[heatmap_scaled] # ? still unclear
112
+
113
+ # Make RGB colorized heatmap
114
+ heatmap_colored = (tfk.preprocessing.image.array_to_img(heatmap_colored) # array => PIL
115
+ .resize((img_array.shape[1], img_array.shape[0])))
116
+ heatmap_colored = tfk.preprocessing.image.img_to_array(
117
+ heatmap_colored) # PIL => array
118
+
119
+ # Overlay image with heatmap
120
+ overlaid_img = heatmap_colored * alpha + img_array_scaled
121
+ overlaid_img = tfk.preprocessing.image.array_to_img(overlaid_img)
122
+ return overlaid_img
123
+
124
+
125
+ if __name__ == '__main__':
126
+ # Test GradCAM
127
+ examples = sorted(glob(os.path.join('examples', '*.jpg')))
128
+ idx2lab, lab2idx = get_imagenet_classes()
129
+
130
+ model, grad_model, preprocessor, decode_predictions = get_xception_model()
131
+
132
+ img_4d_array = get_img_4d_array(examples[0])
133
+ img_4d_array = preprocessor(img_4d_array)
134
+
135
+ heatmap = make_gradcam_heatmap(grad_model, img_4d_array, pred_idx=None)
136
+
137
+ img_pil = align_image_with_heatmap(
138
+ img_4d_array, heatmap, alpha=0.3, cmap='jet')
139
+
140
+ img_pil.save('test.jpg')
141
+ print('done')