本文最后更新于 2024-12-20，文章内容可能已经过时。

实现热力图可视化的三种方式：

  • 普通热力图

  • 仅标签框内显示热力图

  • 只显示热力值高的部分

下面是完整的源代码（实现的是第一种：普通热力图）。这种方式无需修改模型源代码的结构即可绘制热力图，只需一个文件。之前尝试过另一种方式，需要修改 Detection Model 文件的结构，还要额外添加两三个文件，但效果不太好。

import warnings
from pathlib import Path
warnings.filterwarnings('ignore')
warnings.simplefilter('ignore')
import torch, yaml, cv2, os, shutil
import numpy as np

np.random.seed(0)
import matplotlib.pyplot as plt
from tqdm import trange
from PIL import Image
from models.yolo_hyperspectral import Model
from utils.general import intersect_dicts
from utils.augmentations_hyperspectral import letterbox
from utils.general import xywh2xyxy
from pytorch_grad_cam import GradCAMPlusPlus, GradCAM, XGradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image
from pytorch_grad_cam.activations_and_gradients import ActivationsAndGradients
import tifffile as tiff


def hyperspectral_to_rgb(hyperspectral_image, red_band_index=45, green_band_index=29, blue_band_index=14):
    """
    Build a false-colour RGB image from three bands of a hyperspectral cube.

    All three channels share a single min/max stretch, so relative intensities
    between the chosen bands are preserved.

    :param hyperspectral_image: hyperspectral image as a numpy array of shape (H, W, C)
    :param red_band_index: band used for the red channel (default 45)
    :param green_band_index: band used for the green channel (default 29)
    :param blue_band_index: band used for the blue channel (default 14)
    :return: uint8 RGB image of shape (H, W, 3), stretched to [0, 255]
    :raises ValueError: if the input is not 3-D, or all selected values are equal
    :raises IndexError: if any band index lies outside [0, C-1]
    """
    if hyperspectral_image.ndim != 3:
        raise ValueError("高光谱图像必须是三维数组,形状为 (H, W, C)")

    num_bands = hyperspectral_image.shape[2]
    band_spec = (
        ("red_band_index", red_band_index),
        ("green_band_index", green_band_index),
        ("blue_band_index", blue_band_index),
    )

    # Validate in red -> green -> blue order so the first bad index is the one reported.
    for label, band in band_spec:
        if not (0 <= band < num_bands):
            raise IndexError(f"{label} {band} 超出范围 (0, {num_bands - 1})")

    # Stack the selected (H, W) planes along a new last axis -> (H, W, 3).
    stacked = np.stack([hyperspectral_image[:, :, band] for _, band in band_spec], axis=-1)

    lowest = np.min(stacked)
    highest = np.max(stacked)
    if highest - lowest == 0:
        raise ValueError("图像的最小值和最大值相同,无法归一化。")

    unit = (stacked - lowest) / (highest - lowest)
    return np.clip(unit * 255, 0, 255).astype(np.uint8)

def adjust_brightness(rgb_image, brightness_factor=1.2):
    """
    Scale an RGB image's intensities by a constant factor.

    The multiplication is done in float32, and the result saturates to
    [0, 255] before conversion, so bright pixels clip instead of wrapping.

    :param rgb_image: RGB image in [0, 255], shape (H, W, 3), dtype uint8
    :param brightness_factor: multiplier (> 1 brightens, < 1 darkens)
    :return: uint8 image of the same shape
    """
    return np.clip(rgb_image.astype(np.float32) * brightness_factor, 0, 255).astype(np.uint8)

class yolov5_heatmap:
    """Grad-CAM style heatmap generator for the hyperspectral YOLOv5 model.

    Calling an instance writes one PNG per selected prediction:
    - Predictions are sorted by their first score column (objectness/confidence).
    - The top ``ratio`` fraction of them is visited.
    - Iteration stops at the first prediction whose confidence is below
      ``conf_threshold``.
    - Predictions whose saliency map is constant are skipped.
    """

    def __init__(self, weight, cfg, device, method, layer, backward_type, conf_threshold, ratio):
        """
        :param weight: checkpoint path; ``ckpt['model']`` must expose ``names`` and ``state_dict()``
        :param cfg: model config passed to ``Model``
        :param device: torch device string, e.g. ``'cuda:0'``
        :param method: name of a CAM class visible in this module, e.g. ``'GradCAM'`` (evaluated)
        :param layer: expression selecting the target layer, evaluated with ``model`` in scope,
            e.g. ``'model.model[45]'``
        :param backward_type: ``'conf'`` back-propagates column 0; anything else back-propagates
            the max of columns 1:
        :param conf_threshold: confidence below which iteration stops
        :param ratio: fraction of sorted predictions to visit
        """
        device = torch.device(device)
        # Load tensors onto CPU so a GPU-saved checkpoint also loads on CPU-only hosts;
        # the weights are copied into the freshly built model on `device` below.
        # NOTE(review): torch>=2.6 defaults to weights_only=True, which rejects pickled
        # model objects; pass weights_only=False there (trusted checkpoints only).
        ckpt = torch.load(weight, map_location='cpu')
        model_names = ckpt['model'].names
        csd = ckpt['model'].float().state_dict()  # checkpoint state_dict as FP32
        model = Model(cfg, ch=112, nc=len(model_names)).to(device)
        csd = intersect_dicts(csd, model.state_dict(), exclude=['anchor'])  # intersect
        model.load_state_dict(csd, strict=False)  # load
        model.eval()
        print(f'Transferred {len(csd)}/{len(model.state_dict())} items')

        # NOTE(review): `layer` and `method` go through eval(); only pass trusted,
        # hard-coded strings. `layer` resolves against the local name `model` above.
        target_layers = [eval(layer)]
        method = eval(method)

        colors = np.random.uniform(0, 255, size=(len(model_names), 3)).astype(np.int_)

        # Explicit attributes instead of `self.__dict__.update(locals())`, which also
        # kept the entire checkpoint (`ckpt`), the raw state dict and a `self.self`
        # reference cycle alive for the object's lifetime.
        self.weight = weight
        self.cfg = cfg
        self.device = device
        self.layer = layer
        self.backward_type = backward_type
        self.conf_threshold = conf_threshold
        self.ratio = ratio
        self.model = model
        self.model_names = model_names
        self.target_layers = target_layers
        self.method = method
        self.colors = colors

    def post_process(self, result):
        """Sort batch item 0's predictions by their first score column, descending.

        :param result: raw head output; ``[..., :4]`` are xywh boxes and ``[..., 4:]`` scores
            (column 0 used as confidence, columns 1: as class scores elsewhere in this class)
        :return: (sorted score tensor, xyxy boxes as a numpy array in the same order)
        """
        logits_ = result[..., 4:]
        boxes_ = result[..., :4]
        _, order = torch.sort(logits_[..., 0], descending=True)
        return logits_[0][order[0]], xywh2xyxy(boxes_[0][order[0]]).cpu().detach().numpy()

    def draw_detections(self, box, color, name, img):
        """Draw one box and its label onto ``img`` (in place) and return it.

        :param box: (xmin, ymin, xmax, ymax); values are truncated to int
        :param color: three numeric components, converted to ints for OpenCV
        :param name: label text, drawn 5 px above the top-left corner
        :param img: image array that OpenCV draws on
        """
        xmin, ymin, xmax, ymax = list(map(int, list(box)))
        cv2.rectangle(img, (xmin, ymin), (xmax, ymax), tuple(int(x) for x in color), 2)
        cv2.putText(img, str(name), (xmin, ymin - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.8, tuple(int(x) for x in color), 2,
                    lineType=cv2.LINE_AA)
        return img

    def _load_image(self, img_path):
        """Read a TIFF cube and return a letterboxed float32 (H, W, C) array scaled by 1/255."""
        img = tiff.imread(img_path)
        # Convert shape from (C, H, W) to (H, W, C); assumes the file is channel-first.
        img = np.ascontiguousarray(np.transpose(img, (1, 2, 0)))  # ensure contiguous memory layout
        img = letterbox(img)[0]
        # NOTE(review): /255 assumes 8-bit-range data; confirm for this sensor.
        return np.float32(img) / 255.0

    def _saliency_map(self, activations, gradients, width, height):
        """Combine activations with CAM weights, ReLU, resize and min-max normalize.

        :return: map in [0, 1] of shape (height, width), or None when the map is
            constant and therefore cannot be normalized
        """
        b, k, _, _ = gradients.size()
        # get_cam_weights is called unbound, with the CAM class standing in for `self`.
        weights = self.method.get_cam_weights(self.method, None, None, None, activations,
                                              gradients.detach().numpy())
        weights = weights.reshape((b, k, 1, 1))
        saliency_map = np.sum(weights * activations, axis=1)
        saliency_map = np.squeeze(np.maximum(saliency_map, 0))
        saliency_map = cv2.resize(saliency_map, (width, height))
        lo, hi = saliency_map.min(), saliency_map.max()
        if (hi - lo) == 0:
            return None
        return (saliency_map - lo) / (hi - lo)

    def _background(self, img):
        """Build the float32 [0, 1] RGB overlay background from the preprocessed cube."""
        rgb_img = hyperspectral_to_rgb(img, red_band_index=45, green_band_index=29, blue_band_index=14)
        rgb_img = adjust_brightness(rgb_img, brightness_factor=2)
        # `img` is already letterboxed; this second letterbox presumably leaves the size
        # unchanged, which show_cam_on_image needs so it matches the saliency map.
        rgb_img = letterbox(rgb_img)[0]
        return np.float32(rgb_img) / 255.0

    def __call__(self, img_path, save_path):
        """Render heatmap PNGs named ``{i}.png`` into ``save_path``.

        :param img_path: path to the hyperspectral TIFF
        :param save_path: output directory; deleted and recreated on every call
        """
        # remove dir if exist, then (re)create it
        if os.path.exists(save_path):
            shutil.rmtree(save_path)
        os.makedirs(save_path, exist_ok=True)

        img = self._load_image(img_path)
        tensor = torch.from_numpy(np.transpose(img, axes=[2, 0, 1])).unsqueeze(0).to(self.device)

        # Registers forward/backward hooks on the target layers.
        grads = ActivationsAndGradients(self.model, self.target_layers, reshape_transform=None)
        try:
            result = grads(tensor)
            activations = grads.activations[0].cpu().detach().numpy()
            post_result, post_boxes = self.post_process(result[0])

            rgb_img = None  # identical for every prediction; built once on first use
            for i in trange(int(post_result.size(0) * self.ratio)):
                # Scores are sorted descending, so all remaining ones are below threshold too.
                if post_result[i][0] < self.conf_threshold:
                    break

                self.model.zero_grad()
                # retain_graph=True: the same forward graph is reused for every prediction.
                if self.backward_type == 'conf':
                    post_result[i, 0].backward(retain_graph=True)
                else:
                    # get max probability for this prediction
                    post_result[i, 1:].max().backward(retain_graph=True)

                saliency_map = self._saliency_map(activations, grads.gradients[0],
                                                  tensor.size(3), tensor.size(2))
                if saliency_map is None:
                    continue

                if rgb_img is None:
                    rgb_img = self._background(img)

                # Boxes are not drawn; self.draw_detections(post_boxes[i], ...) can overlay them here.
                cam_image = show_cam_on_image(rgb_img, saliency_map, use_rgb=True)
                Image.fromarray(cam_image).save(f'{save_path}/{i}.png')
        finally:
            # Remove the hooks; otherwise every call stacks another set onto self.model.
            grads.release()


def get_params():
    """Return the hard-coded keyword arguments for ``yolov5_heatmap``.

    A fresh dict is built on every call.
    - ``method``: one of GradCAMPlusPlus, GradCAM, XGradCAM.
    - ``backward_type``: 'class' or 'conf'.
    - ``conf_threshold``: typically 0.6.
    - ``ratio``: typically 0.02-0.1.
    """
    return dict(
        weight='/home7/zy/yolov5-7.0/weights/SZU640_Extend.pt',
        cfg='/home7/zy/yolov5-7.0/models/models_hyperspectral/onestream/TwoBranch/MixTitos_TFBCSSM_C3SS2.yaml',
        device='cuda:0',
        method='GradCAM',
        layer='model.model[45]',
        backward_type='conf',
        conf_threshold=0.6,
        ratio=0.02,
    )


if __name__ == '__main__':
    # Build the generator from the hard-coded parameters, then render heatmaps
    # for one validation image into ./result (the directory is recreated).
    heatmap = yolov5_heatmap(**get_params())
    source_image = r'/home7/zy/yolov5-7.0/dataset/SZU640_Extend/hs_information/images/val2/201.tiff'
    heatmap(source_image, 'result')