Spaces:
Sleeping
Sleeping
import torch | |
import numpy as np | |
import cv2 | |
from chexnet import ChexNet | |
from layers import SaveFeature | |
from constant import CLASS_NAMES | |
class HeatmapGenerator:
    """Produce a normalized activation heatmap for a ChexNet prediction.

    Activations are captured from ``chexnet.backbone`` through a
    ``SaveFeature`` helper, reduced to a 2-D map (see :meth:`cam` and
    :meth:`default`), shifted/scaled into [0, 1] and resized to the
    requested output size.
    """

    def __init__(self, chexnet, mode=None):
        """Attach to *chexnet* and choose the heatmap mapping.

        Args:
            chexnet: model exposing ``backbone``, ``head`` and ``predict``.
            mode: ``'cam'`` weights the captured feature maps by the head's
                last-layer weights for the predicted class; any other value
                (including the default ``None``) uses :meth:`default`.
        """
        self.chexnet = chexnet
        # NOTE(review): assumes SaveFeature keeps the backbone's most recent
        # output in ``.features`` (e.g. via a forward hook) -- confirm.
        self.sf = SaveFeature(chexnet.backbone)
        # First parameter of the head's last child module; presumably the
        # (num_classes, channels) weight of the final linear layer.
        last_layer = list(self.chexnet.head.children())[-1]
        self.weight = list(last_layer.parameters())[0]
        self.mapping = self.cam if mode == 'cam' else self.default

    def cam(self, pred_y):
        """Return the class activation map for class index *pred_y*.

        The first captured feature map has its leading axis moved last and
        is dotted with ``self.weight[pred_y]``, giving a 2-D numpy array.
        """
        # assumes features are (batch, channels, H, W); only item 0 is used.
        # .cpu() so this also works when the model runs on a GPU.
        features = self.sf.features[0].permute(1, 2, 0).detach().cpu().numpy()
        class_weight = self.weight[pred_y].detach().cpu().numpy()
        return features @ class_weight

    def default(self, pred_y):
        """Return a class-agnostic map: per-pixel max of |activation|.

        *pred_y* is accepted only so this method is interchangeable with
        :meth:`cam` as ``self.mapping``; it is ignored. Like :meth:`cam`,
        only the first captured item is used, so the result is a 2-D numpy
        array.
        """
        features = self.sf.features[0].detach()
        # Reduce over the leading (channel) axis of the single item.
        return features.abs().amax(dim=0).cpu().numpy()

    @staticmethod
    def _normalize(heatmap):
        """Shift *heatmap* to a minimum of 0 and scale its maximum to 1.

        A constant map becomes all zeros instead of NaN (0 / 0).
        """
        heatmap = heatmap - np.min(heatmap)
        peak = np.max(heatmap)
        if peak > 0:
            heatmap = heatmap / peak
        return heatmap

    def generate(self, image):
        """Predict on *image* and return ``(heatmap, class_name)``.

        ``image.size`` is unpacked as ``(width, height)``; presumably a PIL
        image -- confirm for other input types.
        """
        # NOTE(review): assumes predict() runs the forward pass that refreshes
        # self.sf.features -- confirm.
        prob = self.chexnet.predict(image)
        w, h = image.size
        return self.from_prob(prob, w, h)

    def from_prob(self, prob, w, h):
        """Build the heatmap for the most probable class.

        Args:
            prob: class scores; ``np.argmax`` selects the class index.
            w: output width in pixels.
            h: output height in pixels.

        Returns:
            ``(heatmap, class_name)``: a map with values in [0, 1] resized to
            ``h`` rows by ``w`` columns, and ``CLASS_NAMES[argmax(prob)]``.
        """
        pred_y = int(np.argmax(prob))
        heatmap = self._normalize(self.mapping(pred_y))
        # cv2.resize takes dsize as (width, height).
        heatmap = cv2.resize(heatmap, (w, h))
        return heatmap, CLASS_NAMES[pred_y]