
# coding: utf-8
import cv2
import numpy as np
import torch
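# Purpose of the class:
# 1. Define the hooks that capture gradients and feature maps
# 2. Register the hooks on a network layer
# 3. Run the network's forward and backward passes
# 4. Produce a heat map from the captured gradients and features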
class ShowGradCam:
    """Record a layer's activations and gradients and render a Grad-CAM overlay.

    Typical use: construct the object with the layer to inspect, run the
    network forward, call ``backward()`` on the score of interest, then call
    :meth:`show_on_img` to write the heat-map overlay to disk.

    Attributes:
        conv_layer: the hooked ``torch.nn.Module``.
        feature_res: list of outputs recorded by the forward hook, in call order.
        grad_res: list of output gradients recorded by the backward hook,
            in call order.
    """

    def __init__(self, conv_layer):
        """Register the forward and backward hooks on ``conv_layer``.

        :param conv_layer: torch.nn.Module whose output is visualised
        :raises TypeError: if ``conv_layer`` is not a ``torch.nn.Module``
        """
        # Explicit raise instead of ``assert`` so the check survives ``python -O``.
        if not isinstance(conv_layer, torch.nn.Module):
            raise TypeError("input layer should be torch.nn.Module")
        self.conv_layer = conv_layer
        # Create the buffers before the hooks exist so a hook can never fire
        # against a half-initialised object.
        self.grad_res = []
        self.feature_res = []
        # Keep the handles so the hooks can be detached again (remove_hooks).
        # NOTE(review): register_backward_hook is the legacy API; only grad_out
        # is read here. Consider register_full_backward_hook once it has been
        # confirmed to work with the models this class is used on.
        self._handles = [
            conv_layer.register_forward_hook(self.farward_hook),
            conv_layer.register_backward_hook(self.backward_hook),
        ]

    def remove_hooks(self):
        """Detach the hooks this object registered on ``conv_layer``."""
        for handle in self._handles:
            handle.remove()
        self._handles = []

    def clear(self):
        """Forget every recorded feature map and gradient."""
        del self.grad_res[:]
        del self.feature_res[:]

    def backward_hook(self, module, grad_in, grad_out):
        """Backward hook: store the gradient w.r.t. the layer's first output."""
        self.grad_res.append(grad_out[0].detach())

    def farward_hook(self, module, input, output):
        """Forward hook: store the layer output.

        Tensor outputs are detached so that the stored copy does not keep the
        autograd graph alive; any other output type is stored unchanged.
        (The method name keeps its historical spelling for compatibility.)
        """
        if isinstance(output, torch.Tensor):
            output = output.detach()
        self.feature_res.append(output)

    # Correctly spelled alias for the forward hook.
    forward_hook = farward_hook

    def gen_cam(self, feature_map, grads, size=(32, 32)):
        """
        Build the class-activation map from a feature map and its gradients.

        :param feature_map: np.array, in [C, H, W]
        :param grads: np.array, in [C, H, W]
        :param size: (width, height) passed to ``cv2.resize``; ``None`` keeps
            the feature map's own resolution. Defaults to (32, 32).
        :return: np.array float32 of shape [H, W] (or ``size`` after resizing),
            scaled so that the maximum is 1; all zeros when the map is constant
        """
        # One weight per channel: the spatial mean of that channel's gradient.
        weights = np.mean(grads, axis=(1, 2))
        # Weighted sum over the channel axis -> (H, W).
        cam = np.tensordot(weights, feature_map, axes=1).astype(np.float32)
        cam = np.maximum(cam, 0)  # ReLU: keep only positive contributions
        if size is not None:
            cam = cv2.resize(cam, tuple(size))
        cam -= np.min(cam)
        peak = np.max(cam)
        # A constant map (e.g. everything removed by the ReLU) has peak == 0;
        # dividing would fill the result with NaN.
        if peak > 0:
            cam /= peak
        return cam

    @staticmethod
    def _first_sample(tensor):
        """Return ``tensor`` as a numpy array, reduced to the first batch item.

        Indexing the batch axis (rather than ``squeeze()``) leaves size-1
        channel or spatial axes intact and also copes with batches larger
        than one.
        """
        arr = tensor.detach().cpu().numpy()
        return arr[0] if arr.ndim == 4 else arr

    def show_on_img(self, input_img, save_path='grad_feature.jpg'):
        '''
        Blend the heat map of the most recent forward/backward pass onto an image.

        :param input_img: cv2 image (ndarray) or path of an image file;
            presumably 8-bit, since pixel values are divided by 255 -- TODO confirm
        :param save_path: where the overlay is written (default grad_feature.jpg)
        :return: None; the overlay is saved to ``save_path``
        :raises OSError: if the image cannot be read or the result not written
        :raises RuntimeError: if no forward/backward pass has been recorded
        '''
        if isinstance(input_img, str):
            img_path = input_img
            input_img = cv2.imread(img_path)
            # cv2.imread signals failure by returning None instead of raising.
            if input_img is None:
                raise OSError('could not read image: %s' % img_path)
        if not self.feature_res or not self.grad_res:
            raise RuntimeError('no feature map / gradient recorded; run forward '
                               'and backward through the hooked layer first')
        if input_img.ndim == 2:
            # Single-channel image: expand so it can be added to the 3-channel heat map.
            input_img = cv2.cvtColor(input_img, cv2.COLOR_GRAY2BGR)
        img_size = (input_img.shape[1], input_img.shape[0])  # (width, height)
        # Use the latest records so a reused object does not render stale data.
        fmap = self._first_sample(self.feature_res[-1])
        grads_val = self._first_sample(self.grad_res[-1])
        # Resize straight to the image size instead of via a fixed 32x32 grid.
        cam = self.gen_cam(fmap, grads_val, size=img_size)
        heatmap = cv2.applyColorMap(np.uint8(255 * cam), cv2.COLORMAP_JET) / 255.
        overlay = heatmap + np.float32(input_img / 255.)
        overlay = overlay / np.max(overlay) * 255
        # Hand imwrite an explicit 8-bit image rather than a float array.
        overlay = np.uint8(np.clip(np.rint(overlay), 0, 255))
        if not cv2.imwrite(save_path, overlay):
            raise OSError('could not write image: %s' % save_path)
        print('save gradcam result in %s' % save_path)