YOLOv7输出层之间的热力图

2024-08-31 20:36
文章标签 输出 之间 力图 yolov7

本文主要是介绍YOLOv7输出层之间的热力图,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

我们经常看到一些论文里绘制了不同的热力图,来直观的感受其模型的有效性。特别是使用了注意力模块的网络,热力图就可以验证注意力机制是否真正聚焦到了预期的重要特征上,以便对模型的有效性和合理性进行评估。

例如Centralized Feature Pyramid for Object Detection这篇文章中展示的,就很能够表达作者改进后的模型相比之前模型的一个优越性。

在这里插入图片描述
本文就来记录一下如何使用python脚本来输出YOLOv7层之间的热力图。

添加步骤

1️⃣ 在本地的YOLOv7项目的根目录下新建heatmap.py,将以下代码复制到其中

import warnings
warnings.filterwarnings('ignore')
warnings.simplefilter('ignore')
import torch, yaml, cv2, os, shutil
import torch.nn as nn
import numpy as np
np.random.seed(0)
import matplotlib.pyplot as plt
from tqdm import trange
from PIL import Image
from models.yolo import Model
from utils.torch_utils import intersect_dicts
from utils.datasets import letterbox
from utils.general import xywh2xyxy
from pytorch_grad_cam import GradCAMPlusPlus, GradCAM, XGradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image
from pytorch_grad_cam.activations_and_gradients import ActivationsAndGradientsclass yolov7_heatmap:def __init__(self, weight, cfg, device, method, layer, backward_type, conf_threshold, ratio):device = torch.device(device)ckpt = torch.load(weight)model_names = ckpt['model'].namescsd = ckpt['model'].float().state_dict()  # checkpoint state_dict as FP32model = Model(cfg, ch=3, nc=len(model_names)).to(device)csd = intersect_dicts(csd, model.state_dict(), exclude=['anchor'])  # intersectmodel.load_state_dict(csd, strict=False)  # loadmodel.eval()print(f'Transferred {len(csd)}/{len(model.state_dict())} items')target_layers = [eval(layer)]method = eval(method)colors = np.random.uniform(0, 255, size=(len(model_names), 3)).astype(np.int)self.__dict__.update(locals())def post_process(self, result):boxes_ = result[0][..., :4]logits_ = []for data in result[1]:bs, n, w, h, _ = data.size()logits_.append(data.reshape((bs, n * w * h, _)))logits_ = torch.cat(logits_, dim=1)[..., 4:]sorted, indices = torch.sort(logits_[..., 0], descending=True)logits_ = logits_[0][indices[0]]logits_[:, 0] = torch.sigmoid(logits_[:, 0])return logits_, xywh2xyxy(boxes_[0][indices[0]]).cpu().detach().numpy()def draw_detections(self, box, color, name, img):xmin, ymin, xmax, ymax = list(map(int, list(box)))cv2.rectangle(img, (xmin, ymin), (xmax, ymax), tuple(int(x) for x in color), 2)cv2.putText(img, str(name), (xmin, ymin - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.8, tuple(int(x) for x in color), 2, lineType=cv2.LINE_AA)return imgdef __call__(self, img_path, save_path):# remove dir if existif os.path.exists(save_path):shutil.rmtree(save_path)# make dir if not existos.makedirs(save_path, exist_ok=True)# img processimg = cv2.imread(img_path)img = letterbox(img)[0]img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)img = np.float32(img) / 255.0tensor = torch.from_numpy(np.transpose(img, axes=[2, 0, 1])).unsqueeze(0).to(self.device)# init ActivationsAndGradientsgrads = ActivationsAndGradients(self.model, self.target_layers, reshape_transform=None)# get ActivationsAndResultresult = grads(tensor)activations = grads.activations[0].cpu().detach().numpy()# postprocess to yolo outputpost_result, post_boxes = self.post_process(result)for i in trange(int(post_result.size(0) * self.ratio)):if post_result[i][0] < self.conf_threshold:breakself.model.zero_grad()if self.backward_type == 'conf':post_result[i, 0].backward(retain_graph=True)else:# get max probability for this predictionscore = post_result[i, 1:].max()score.backward(retain_graph=True)# process heatmapgradients = grads.gradients[0]b, k, u, v = gradients.size()weights = self.method.get_cam_weights(self.method, None, None, None, activations, gradients.detach().numpy())weights = weights.reshape((b, k, 1, 1))saliency_map = np.sum(weights * activations, axis=1)saliency_map = np.squeeze(np.maximum(saliency_map, 0))saliency_map = cv2.resize(saliency_map, (tensor.size(3), tensor.size(2)))saliency_map_min, saliency_map_max = saliency_map.min(), saliency_map.max()if (saliency_map_max - saliency_map_min) == 0:continuesaliency_map = (saliency_map - saliency_map_min) / (saliency_map_max - saliency_map_min)# add heatmap and box to imagecam_image = show_cam_on_image(img.copy(), saliency_map, use_rgb=True)#cam_image = self.draw_detections(post_boxes[i], self.colors[int(post_result[i, 1:].argmax())], f'{self.model_names[int(post_result[i, 1:].argmax())]} {post_result[i][0]:.2f}', cam_image)cam_image = Image.fromarray(cam_image)cam_image.save(f'{save_path}/{i}.png')def get_params():params = {'weight': 'runs/train/exp/weights/best.pt',  'cfg': 'cfg/training/yolov7_test.yaml','device': 'cuda:0','method': 'GradCAM', # GradCAMPlusPlus, GradCAM, XGradCAM'layer': 'model.model[-2]',  'backward_type': 'class', # class or conf'conf_threshold': 0.6, # 0.6'ratio': 0.02 # 0.02-0.1}return paramsif __name__ == '__main__':model = yolov7_heatmap(**get_params())model('inference/heat_image/001.jpg', 'heat_result')

2️⃣ 修改配置参数

文件中的主要参数配置如下:

在这里插入图片描述

参数解释
weight权重路径,训练完成后的权重文件
cfg模型文件路径,与权重所训练出来的模型文件一致
device运行的设备,和模型训练时的device参数设置一致
method可选择GradCAM,GradCAMPlusPlus和XGradCAM ,可以都试试,效果不同
layer想要输出第几层的热力图就写几,我这里写的的-2,即倒数第二层,可以多换换,看看效果
backward_type反向传播的计算类型,class表示按照类别最大概率进行计算 或 conf 通过置信度计算梯度
conf_threshold置信度阈值,设置成0.6
ratio取前多少数据,设置成0.02

在这里插入图片描述

箭头指向的数据就是行号。

3️⃣ 数据源

在这里插入图片描述
model('inference/heat_image/001.jpg', 'heat_result')中:

第一个参数inference/heat_image/001.jpg表示想要进行热力图绘制的原图像路径。

第二个参数'heat_result'表示绘制完成后输出的文件夹路径。

4️⃣ 调试

在这里插入图片描述
此时就已经绘制完成了,在指定的文件夹下就已经输出了热力图了。进度条还没有满就停止,是因为后面的目标已经不满足置信度conf_threshold的设定值。

这个进度条的长度151是之前设定的参数ratio的结果,其只会选择前0.02的目标进行热力图可视化。

博客参考链接
代码参考链接

这篇关于YOLOv7输出层之间的热力图的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1124930

相关文章

java -jar example.jar 产生的日志输出到指定文件的方法

《java-jarexample.jar产生的日志输出到指定文件的方法》这篇文章给大家介绍java-jarexample.jar产生的日志输出到指定文件的方法,本文给大家介绍的非常详细,对大家的... 目录怎么让 Java -jar example.jar 产生的日志输出到指定文件一、方法1:使用重定向1、

Java中数组与栈和堆之间的关系说明

《Java中数组与栈和堆之间的关系说明》文章讲解了Java数组的初始化方式、内存存储机制、引用传递特性及遍历、排序、拷贝技巧,强调引用数据类型方法调用时形参可能修改实参,但需注意引用指向单一对象的特性... 目录Java中数组与栈和堆的关系遍历数组接下来是一些编程小技巧总结Java中数组与栈和堆的关系关于

在Java中实现线程之间的数据共享的几种方式总结

《在Java中实现线程之间的数据共享的几种方式总结》在Java中实现线程间数据共享是并发编程的核心需求,但需要谨慎处理同步问题以避免竞态条件,本文通过代码示例给大家介绍了几种主要实现方式及其最佳实践,... 目录1. 共享变量与同步机制2. 轻量级通信机制3. 线程安全容器4. 线程局部变量(ThreadL

Spring Boot集成/输出/日志级别控制/持久化开发实践

《SpringBoot集成/输出/日志级别控制/持久化开发实践》SpringBoot默认集成Logback,支持灵活日志级别配置(INFO/DEBUG等),输出包含时间戳、级别、类名等信息,并可通过... 目录一、日志概述1.1、Spring Boot日志简介1.2、日志框架与默认配置1.3、日志的核心作用

Javaee多线程之进程和线程之间的区别和联系(最新整理)

《Javaee多线程之进程和线程之间的区别和联系(最新整理)》进程是资源分配单位,线程是调度执行单位,共享资源更高效,创建线程五种方式:继承Thread、Runnable接口、匿名类、lambda,r... 目录进程和线程进程线程进程和线程的区别创建线程的五种写法继承Thread,重写run实现Runnab

在Linux中改变echo输出颜色的实现方法

《在Linux中改变echo输出颜色的实现方法》在Linux系统的命令行环境下,为了使输出信息更加清晰、突出,便于用户快速识别和区分不同类型的信息,常常需要改变echo命令的输出颜色,所以本文给大家介... 目python录在linux中改变echo输出颜色的方法技术背景实现步骤使用ANSI转义码使用tpu

C# 比较两个list 之间元素差异的常用方法

《C#比较两个list之间元素差异的常用方法》:本文主要介绍C#比较两个list之间元素差异,本文通过实例代码给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友参考下吧... 目录1. 使用Except方法2. 使用Except的逆操作3. 使用LINQ的Join,GroupJoin

深度解析Java项目中包和包之间的联系

《深度解析Java项目中包和包之间的联系》文章浏览阅读850次,点赞13次,收藏8次。本文详细介绍了Java分层架构中的几个关键包:DTO、Controller、Service和Mapper。_jav... 目录前言一、各大包1.DTO1.1、DTO的核心用途1.2. DTO与实体类(Entity)的区别1

java Long 与long之间的转换流程

《javaLong与long之间的转换流程》Long类提供了一些方法,用于在long和其他数据类型(如String)之间进行转换,本文将详细介绍如何在Java中实现Long和long之间的转换,感... 目录概述流程步骤1:将long转换为Long对象步骤2:将Longhttp://www.cppcns.c

使用Java将实体类转换为JSON并输出到控制台的完整过程

《使用Java将实体类转换为JSON并输出到控制台的完整过程》在软件开发的过程中,Java是一种广泛使用的编程语言,而在众多应用中,数据的传输和存储经常需要使用JSON格式,用Java将实体类转换为J... 在软件开发的过程中,Java是一种广泛使用的编程语言,而在众多应用中,数据的传输和存储经常需要使用j