Outputting and Visualizing Intermediate-Layer Features and Class Activation Heatmaps in a Deep Learning Framework

2023-12-29 13:28

This article introduces how to output and visualize intermediate-layer features and class activation heatmaps in a deep learning framework, and hopefully provides a useful reference for developers facing similar problems.

Sometimes, in order to analyze the intermediate-layer features of a deep learning model, we need to output those features for inspection. Here is one way to do it.

(1) Print the names of the intermediate feature layers

Import the required libraries and load the model:

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torch.nn import functional as F
from torchvision import transforms
import numpy as np
from PIL import Image
from collections import OrderedDict
import cv2
from models.xxx import Model  # load your own model; here xxx is the name of your model file
import os
device = torch.device('cuda:0')
model = Model().to(device)
print(model)

The output is shown below; only part of the model's intermediate layers is reproduced here.

Model((res): ResNet50((conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3))(maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)(layer1): Sequential((0): ResNet50DownBlock((conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1))(bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(extra): Sequential((0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1))(1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)))(1): ResNet50BasicBlock((conv1): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1))(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1))(bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))(2): ResNet50BasicBlock((conv1): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1))(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1))(bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)))(layer2): Sequential((0): ResNet50DownBlock((conv1): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1))(bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1))(bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(extra): Sequential((0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2))(1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)))(1): ResNet50BasicBlock((conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1))(bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1))(bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))(2): ResNet50BasicBlock((conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1))(bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1))(bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))(3): ResNet50DownBlock((conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1))(bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, 
affine=True, track_running_stats=True)(conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1))(bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(extra): Sequential((0): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1))(1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))))(layer3): Sequential((0): ResNet50DownBlock((conv1): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1))(bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1))(bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(extra): Sequential((0): Conv2d(512, 1024, kernel_size=(1, 1), stride=(2, 2))(1): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)))(1): ResNet50BasicBlock((conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1))(bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1))(bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))(2): ResNet50BasicBlock((conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1))(bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1))(bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))(3): ResNet50DownBlock((conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1))(bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1))(bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(extra): Sequential((0): Conv2d(1024, 1024, kernel_size=(1, 1), stride=(1, 1))(1): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)))(4): ResNet50DownBlock((conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1))(bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1))(bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(extra): Sequential((0): Conv2d(1024, 1024, kernel_size=(1, 1), stride=(1, 1))(1): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)))(5): ResNet50DownBlock((conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1))(bn1): 
BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1))(bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(extra): Sequential((0): Conv2d(1024, 1024, kernel_size=(1, 1), stride=(1, 1))(1): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))))
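
If the full printout is hard to scan, the layer names can also be enumerated programmatically with named_modules(). The sketch below is only a minimal example: torchvision's resnet50 is used as a stand-in for the custom Model above, so swap in your own model instance (and note that weights=None requires a recent torchvision).

import torch.nn as nn
from torchvision import models

net = models.resnet50(weights=None)  # stand-in; replace with your own Model() instance

# named_modules() yields (qualified_name, module) pairs for every submodule,
# e.g. "layer1.2.bn3" -- exactly the dotted name you need when picking a layer to hook
for name, module in net.named_modules():
    if isinstance(module, (nn.Conv2d, nn.BatchNorm2d)):
        print(name, '->', module.__class__.__name__)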

(2) Load and preprocess the image

img_path = './dataset//val_data/images/100_0019_0165-11.jpg'
img = Image.open(img_path)
imgarray = np.array(img)/255.0
# plt.figure(figsize=(8, 8))
# plt.imshow(imgarray)
# plt.axis('off')
# plt.show()

The loaded image is shown below.

Preprocess the image into a form the model can take as input:

# Preprocess the image
transform = transforms.Compose([
    transforms.Resize([512, 512]),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
input_img = transform(img).unsqueeze(0)  # unsqueeze(0) adds the batch dimension
# print(input_img.shape)   # torch.Size([1, 3, 512, 512])
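
One practical caveat: transforms.Normalize with three means and stds assumes a 3-channel input, so a grayscale JPEG will fail at this step. A minimal guard, assuming you simply want to force RGB before applying the transform:

from PIL import Image
from torchvision import transforms

transform = transforms.Compose([transforms.Resize([512, 512]), transforms.ToTensor(),
                                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
# convert('RGB') forces 3 channels so Normalize does not fail on grayscale images
img = Image.open(img_path).convert('RGB')
input_img = transform(img).unsqueeze(0)
print(input_img.shape)  # expected: torch.Size([1, 3, 512, 512])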

(3) Visualize an intermediate layer

1. Define the hook function

# Define the hook function
activation = {}  # stores the captured outputs
def get_activation(name):
    def hook(model, input, output):
        activation[name] = output.detach()
    return hook
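
Note that register_forward_hook returns a handle, so the hook can be detached once it is no longer needed. A minimal sketch of the full round trip, using a plain torchvision resnet50 as a stand-in for the model in this article:

import torch
from torchvision import models

net = models.resnet50(weights=None).eval()   # stand-in model for demonstration
handle = net.layer1[2].bn3.register_forward_hook(get_activation('bn3'))

with torch.no_grad():
    _ = net(torch.randn(1, 3, 224, 224))     # one forward pass fills the activation dict
print(activation['bn3'].shape)               # e.g. torch.Size([1, 256, 56, 56])

handle.remove()  # detach the hook so later forward passes are no longer captured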

2. Visualize the intermediate-layer features. A single layer is chosen here; other layers work the same way.

# Visualize the intermediate-layer features
checkpoint = torch.load('./checkpoint_best.pth')  # load the trained weights
model.load_state_dict(checkpoint['model'])
model.eval()
model.res.layer1[2].register_forward_hook(get_activation('bn3'))  # register a hook on bn3 of the third block in ResNet50 layer1
input_img = input_img.to(device)  # move the input from CPU to GPU; whether this is needed depends on your setup (mine errors out without it)
_ = model(input_img)
bn3 = activation['bn3']   # the result is stored in the activation dict; bn3 is a <class 'torch.Tensor'>, which plt cannot display directly
# print(bn3.shape)  # if you get this far, things are basically working
bn3 = bn3.cpu().numpy()  # convert to numpy, shape: (1, 256, 128, 128)
plt.figure(figsize=(8, 8))
plt.imshow(bn3[0][0], cmap='gray')  # bn3[0][0]  shape: (128, 128)
plt.axis('off')
plt.show()
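
The figure above only shows channel 0 of the 256 feature channels. If you want a quick overview of several channels at once, a small subplot grid works; this sketch assumes bn3 is already the numpy array of shape (1, 256, 128, 128) produced above:

# Show the first 16 channels of the captured feature map in a 4x4 grid
fig, axes = plt.subplots(4, 4, figsize=(12, 12))
for i, ax in enumerate(axes.flat):
    ax.imshow(bn3[0][i], cmap='gray')
    ax.set_title('channel {}'.format(i), fontsize=8)
    ax.axis('off')
plt.tight_layout()
plt.show()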

The visualization result:

(4) Output intermediate-layer visualizations for multiple images in a loop

Combining the code above, loop over multiple images from the validation set and save their intermediate-layer visualizations.

# Import dependencies
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torch.nn import functional as F
from torchvision import transforms
import numpy as np
from PIL import Image
from collections import OrderedDict
import cv2
from models.M_SFANet import Model
import os
import glob

# Define the hook function
activation = {}  # stores the captured outputs
def get_activation(name):
    def hook(model, input, output):
        activation[name] = output.detach()
    return hook

# Load the model
device = torch.device('cuda:0')
model = Model().to(device)
checkpoint = torch.load('./checkpoint_best.pth')  # load the trained weights
model.load_state_dict(checkpoint['model'])
model.eval()
model.res.layer1[2].register_forward_hook(get_activation('bn3'))  # register a hook on bn3 of the third block in ResNet50 layer1; use a different layer if needed

# Loop over the images whose features should be visualized
DATA_PATH = f"./val_data/"
img_list = glob.glob(os.path.join(DATA_PATH, "images", "*.jpg"))    # image paths
img_list.sort()
for idx in range(0, len(img_list)):
    img_name = img_list[idx].split('/')[-1].split('.')[0]  # get the file name
    img = Image.open(img_list[idx])  # read the image
    imgarray = np.array(img) / 255.0
    # Preprocess the image
    transform = transforms.Compose([transforms.Resize([512, 512]), transforms.ToTensor(),
                                    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
    input_img = transform(img).unsqueeze(0)  # unsqueeze(0) adds the batch dimension
    input_img = input_img.to(device)  # move the input from CPU to GPU; whether this is needed depends on your setup
    _ = model(input_img)
    bn3 = activation['bn3']   # the result is stored in the activation dict; bn3 is a torch.Tensor, which plt cannot display directly
    bn3 = bn3.cpu().numpy()
    plt.figure(figsize=(8, 8))
    plt.imshow(bn3[0][0], cmap='jet')  # bn3[0][0]  shape: (128, 128)
    plt.axis('off')
    plt.savefig('./feature_out/res50/layer1/{}_res50_layer1'.format(img_name), bbox_inches='tight', pad_inches=0.05, dpi=300)
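
Two small practical notes on the loop above: plt.savefig will raise an error if the output directory does not exist, and creating a new figure per image without closing it will eventually trigger matplotlib's open-figure warning. A minimal fix, assuming the same output path as above:

import os
import matplotlib.pyplot as plt

os.makedirs('./feature_out/res50/layer1', exist_ok=True)  # create the output directory tree once, before the loop

# ... inside the loop, after plt.savefig(...):
plt.close()  # free the figure so figures do not accumulate across iterations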

The saved images in the output folder look like this:

--------------------------------------------------- Updated on 2023.1121.28 -----------------------------------------

(5) Output class activation heatmaps for multiple images in a loop

With class activation heatmaps, you can see which regions of the image are most important for the model's prediction.

Here we take the feature maps obtained above one step further and turn them into class activation heatmaps.

Continuing from where bn3 is captured above, the code is as follows:

    bn3 = activation['bn3']   # the result is stored in the activation dict; bn3 is a torch.Tensor, which plt cannot display directly
    '''
    # The following code outputs the feature map
    bn3 = bn3.cpu().numpy()
    plt.figure(figsize=(8, 8))
    plt.imshow(bn3[0][0], cmap='jet')  # bn3[0][0]  shape: (128, 128)
    plt.axis('off')
    plt.savefig('./feature_out/res50/layer4/{}_res50_layer4'.format(img_name), bbox_inches='tight', pad_inches=0.05, dpi=300)
    '''
    # Overlay the feature map on the original image as a heatmap
    bn3 = bn3[0][0].cpu().numpy()
    bn3 = np.maximum(bn3, 0)
    bn3 /= np.max(bn3)
    # plt.matshow(bn3)
    # plt.show()
    # img1 = cv2.imread('./dataset/ShanghaiTech/part_A_final/val_data/images/100_0019_0165-11.jpg')
    img1 = cv2.cvtColor(np.asarray(img), cv2.COLOR_RGB2BGR)  # convert the PIL Image to a cv2 (BGR) array
    bn3 = cv2.resize(bn3, (img1.shape[1], img1.shape[0]))
    bn3 = np.uint8(255 * bn3)
    bn3 = cv2.applyColorMap(bn3, cv2.COLORMAP_JET)
    heat_img = cv2.addWeighted(img1, 1, bn3, 0.5, 0)
    cv2.imwrite('./heatmap_out/res50/layer1/{}_res50_layer1.jpg'.format(str(img_name)), heat_img)
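
Note that this overlay uses a single channel (bn3[0][0]) rather than a true class activation map. A slightly more informative, still very simple variant is to average over all channels before normalizing; the sketch below would replace the overlay block inside the same loop, assuming bn3 is the captured tensor of shape (1, 256, 128, 128) and img1 is the BGR image from above. For a genuinely class-discriminative heatmap you would weight the channels by gradients of the target score (Grad-CAM), for example with the pytorch-grad-cam library.

    cam = activation['bn3'][0].mean(dim=0).cpu().numpy()   # average over channels, shape: (128, 128)
    cam = np.maximum(cam, 0)
    cam /= (np.max(cam) + 1e-8)                            # normalize to [0, 1], guarding against division by zero
    cam = cv2.resize(cam, (img1.shape[1], img1.shape[0]))
    cam = np.uint8(255 * cam)
    cam = cv2.applyColorMap(cam, cv2.COLORMAP_JET)
    heat_img = cv2.addWeighted(img1, 1, cam, 0.5, 0)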

The output is shown below.

That concludes this article on outputting and visualizing intermediate-layer features and class activation heatmaps. Hopefully it is helpful to other developers!


