import torch
import numpy as np
import cv2

from chexnet import ChexNet
from layers import SaveFeature
from constant import CLASS_NAMES


class HeatmapGenerator:
    """Produce a heatmap for the top-scoring class of a ChexNet prediction.

    Feature maps are captured from ``chexnet.backbone`` by a ``SaveFeature``
    hook. After a forward pass, the heatmap for the predicted class is built
    from the first item of the captured features, min-max normalised to
    [0, 1] and resized to the requested output size.
    """

    def __init__(self, chexnet, mode=None):
        """Attach a feature hook to ``chexnet`` and choose the mapping.

        Args:
            chexnet: model exposing ``backbone``, ``head`` and
                ``predict(image)``.
            mode: ``'cam'`` selects a class activation map (features weighted
                by the last head layer's weights for the predicted class);
                any other value, including the default ``None``, selects the
                per-pixel max absolute activation over channels.
        """
        self.chexnet = chexnet
        self.sf = SaveFeature(chexnet.backbone)
        # First parameter of the head's last child; presumably the final
        # linear layer's weight, shaped (n_classes, C) -- TODO confirm.
        self.weight = list(list(self.chexnet.head.children())[-1].parameters())[0]
        self.mapping = self.cam if mode == 'cam' else self.default

    def cam(self, pred_y):
        """Return the class activation map for class index ``pred_y``.

        The first captured feature item is treated as (C, H, W) and is
        contracted with the class weight vector, giving an (H, W) array.
        """
        # .cpu() so this also works when the model runs on a GPU.
        features = self.sf.features[0].detach().cpu().numpy()
        weights = self.weight[pred_y].detach().cpu().numpy()
        # (H, W, C) @ (C,) -> (H, W)
        return features.transpose(1, 2, 0) @ weights

    def default(self, pred_y=None):
        """Return the per-pixel max absolute activation over channels.

        ``pred_y`` is accepted only so this method is interchangeable with
        ``cam``; it is ignored.
        """
        features = self.sf.features[0].detach().cpu()
        # Reduce over the channel axis of the (C, H, W) item -> (H, W).
        return torch.max(torch.abs(features), dim=0)[0].numpy()

    def generate(self, image):
        """Run a prediction on ``image`` and return ``(heatmap, class_name)``.

        ``image.size`` is unpacked as ``(width, height)``; presumably a PIL
        image -- confirm against callers.
        """
        # predict() performs the forward pass that refreshes self.sf.features.
        prob = self.chexnet.predict(image)
        w, h = image.size
        return self.from_prob(prob, w, h)

    def from_prob(self, prob, w, h):
        """Build the heatmap for ``argmax(prob)`` from the last captured features.

        Args:
            prob: per-class scores from the forward pass that filled the hook.
            w: output width in pixels.
            h: output height in pixels.

        Returns:
            ``(heatmap, class_name)``. ``heatmap`` is an (h, w) float32 array
            scaled to [0, 1]; it is all zeros if the raw map is constant.
        """
        # Plain int so it indexes both the torch weight tensor and CLASS_NAMES.
        pred_y = int(np.argmax(prob))
        heatmap = np.asarray(self.mapping(pred_y), dtype=np.float32)
        heatmap = heatmap - heatmap.min()
        peak = heatmap.max()
        # Avoid 0/0 -> NaN when the activation map is flat.
        if peak > 0:
            heatmap = heatmap / peak
        # cv2.resize takes dsize as (width, height).
        heatmap = cv2.resize(heatmap, (w, h))
        return heatmap, CLASS_NAMES[pred_y]