深度学习框架输出可视化中间层特征与类激活热力图

2023-12-29 13:28

本文主要是介绍深度学习框架输出可视化中间层特征与类激活热力图,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

有时候为了分析深度学习框架的中间层特征,我们需要输出中间层特征进行分析,这里提供一个方法。

(1)输出中间特征层名字

导入所需的库并加载模型

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torch.nn import functional as F
from torchvision import transforms
import numpy as np
from PIL import Image
from collections import OrderedDict
import cv2
from models.xxx import Model  # 加载自己的模型, 这里xxx是自己模型名字
import os
device = torch.device('cuda:0')
model = Model().to(device)
print(model)

输出如下,这里我只截取了部分模型中间层输出

Model((res): ResNet50((conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3))(maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)(layer1): Sequential((0): ResNet50DownBlock((conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1))(bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(extra): Sequential((0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1))(1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)))(1): ResNet50BasicBlock((conv1): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1))(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1))(bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))(2): ResNet50BasicBlock((conv1): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1))(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1))(bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)))(layer2): Sequential((0): ResNet50DownBlock((conv1): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1))(bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1))(bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(extra): Sequential((0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2))(1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)))(1): ResNet50BasicBlock((conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1))(bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1))(bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))(2): ResNet50BasicBlock((conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1))(bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1))(bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))(3): ResNet50DownBlock((conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1))(bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1))(bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(extra): Sequential((0): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1))(1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))))(layer3): Sequential((0): ResNet50DownBlock((conv1): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1))(bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1))(bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(extra): Sequential((0): Conv2d(512, 1024, kernel_size=(1, 1), stride=(2, 2))(1): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)))(1): ResNet50BasicBlock((conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1))(bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1))(bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))(2): ResNet50BasicBlock((conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1))(bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1))(bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))(3): ResNet50DownBlock((conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1))(bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1))(bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(extra): Sequential((0): Conv2d(1024, 1024, kernel_size=(1, 1), stride=(1, 1))(1): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)))(4): ResNet50DownBlock((conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1))(bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1))(bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(extra): Sequential((0): Conv2d(1024, 1024, kernel_size=(1, 1), stride=(1, 1))(1): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)))(5): ResNet50DownBlock((conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1))(bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1))(bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(extra): Sequential((0): Conv2d(1024, 1024, kernel_size=(1, 1), stride=(1, 1))(1): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))))

(2)加载并处理图像

img_path = './dataset//val_data/images/100_0019_0165-11.jpg'
img = Image.open(img_path)
imgarray = np.array(img)/255.0
# plt.figure(figsize=(8, 8))
# plt.imshow(imgarray)
# plt.axis('off')
# plt.show()

加载后如下

将图片处理成模型可以预测的形式

# 处理图像
transform = transforms.Compose([transforms.Resize([512, 512]),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
input_img = transform(img).unsqueeze(0)  # unsqueeze(0)用于升维
# print(input_img.shape)   # torch.Size([1, 3, 512, 512])

(3)可视化中间层

1.定义钩子函数

# 定义钩子函数
activation = {}  # 保存获取的输出
def get_activation(name):def hook(model, input, output):activation[name] = output.detach()return hook

2.可视化中间层特征,这里选择了一个层,其他的自己可以类推

# 可视化中间层特征
checkpoint = torch.load('./checkpoint_best.pth')  # 加载一下权重
model.load_state_dict(checkpoint['model'])
model.eval()
model.res.layer1[2].register_forward_hook(get_activation('bn3'))  #resnet50 layer1中第三个模块的bn3注册钩子
input_img = input_img.to(device)  # cpu数据转一下gpu,这个看你会不会报错,我的不转会报错
_ = model(input_img)
bn3 = activation['bn3']   # 结果将保存在activation字典中  bn3输出<class 'torch.Tensor'>, tensor是无法用plt正常显示的
# print(bn3.shape)  # 调试到这里基本成功了
bn3 = bn3.cpu().numpy() # 转一下numpy,  shape:(1,256, 128, 128) 
plt.figure(figsize=(8,8))
plt.imshow(bn3[0][0], cmap='gray')  # bn3[0][0]  shape:(128, 128)
plt.axis('off')
# # shape:(128, 128)
plt.show()

可视化结果

(4)利用循环输出多张图像可视化中间层

整合上面的代码,利用循环输出验证集中的多张图像中的可视化中间层

# 加载依赖包
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torch.nn import functional as F
from torchvision import transforms
import numpy as np
from PIL import Image
from collections import OrderedDict
import cv2
from models.M_SFANet import Model
import os
import glob# 定义钩子函数
activation = {}  # 保存获取的输出
def get_activation(name):def hook(model, input, output):activation[name] = output.detach()return hook# 加载模型
device = torch.device('cuda:0')
model = Model().to(device)checkpoint = torch.load('./checkpoint_best.pth')  # 加载一下权重
model.load_state_dict(checkpoint['model'])
model.eval()
model.res.layer1[2].register_forward_hook(get_activation('bn3'))  #resnet50 layer1中第三个模块的bn3注册钩子,如果需要其他层数就用其他的# 利用循环输出多个可视化中间层#读取需要输出特征的图像
DATA_PATH = f"./val_data/"
img_list = glob.glob(os.path.join(DATA_PATH, "images", "*.jpg"))    # image 路径
img_list.sort()
for idx in range(0, len(img_list)):img_name = img_list[idx].split('/')[-1].split('.')[0]  # 获取文件名img = Image.open(img_list[idx])  # 可以读到图片imgarray = np.array(img)/255.0# 处理图像transform = transforms.Compose([transforms.Resize([512, 512]),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])input_img = transform(img).unsqueeze(0)  # unsqueeze(0)用于升维input_img = input_img.to(device)  # cpu数据转一下gpu,这个看你会不会报错,我的会报错_ = model(input_img)bn3 = activation['bn3']   # 结果将保存在activation字典中  bn3输出<class 'torch.Tensor'>, tensor是无法用plt正常显示的bn3 = bn3.cpu().numpy() plt.figure(figsize=(8,8))plt.imshow(bn3[0][0], cmap='jet')  # bn3[0][0]  shape:(128, 128)plt.axis('off')# # shape:(128, 128)plt.savefig('./feature_out/res50/layer1/{}_res50_layer1'.format(img_name), bbox_inches='tight', pad_inches=0.05, dpi=300)

保存至文件夹中如下

---------------------------------------------------更新于2023.1121.28 -----------------------------------------

(5)利用循环输出多张图像类激活热力图

使用类激活热力图,能观察模型对图像识别的关键位置。

这里接着上面的获得的特征图进一步得到类激活热力图

接着上面获取到bn3,代码如下

    bn3 = activation['bn3']   # 结果将保存在activation字典中  bn3输出<class 'torch.Tensor'>, tensor是无法用plt正常显示的'''以下代码用于输出特征图bn3 = bn3.cpu().numpy() plt.figure(figsize=(8,8))plt.imshow(bn3[0][0], cmap='jet')  # bn3[0][0]  shape:(128, 128)plt.axis('off')# # shape:(128, 128)plt.savefig('./feature_out/res50/layer4/{}_res50_layer4'.format(img_name), bbox_inches='tight', pad_inches=0.05, dpi=300)'''# 将特征图用类热力图形式叠加到原图中bn3 = bn3[0][0].cpu().numpy()bn3 = np.maximum(bn3, 0)bn3 /= np.max(bn3)# plt.matshow(bn3)# plt.show()# img1 = cv2.imread('./dataset/ShanghaiTech/part_A_final/val_data/images/100_0019_0165-11.jpg')img1 = cv2.cvtColor(np.asarray(img), cv2.COLOR_RGB2BGR)  # PIL Image转一下cv2bn3 = cv2.resize(bn3, (img1.shape[1], img1.shape[0]))bn3 = np.uint8(255 * bn3)bn3 = cv2.applyColorMap(bn3, cv2.COLORMAP_JET)heat_img = cv2.addWeighted(img1, 1, bn3, 0.5, 0)cv2.imwrite('./heatmap_out/res50/layer1/{}_res50_layer1.jpg'.format(str(img_name)), heat_img)

输出如下

这篇关于深度学习框架输出可视化中间层特征与类激活热力图的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/549569

相关文章

Java中Redisson 的原理深度解析

《Java中Redisson的原理深度解析》Redisson是一个高性能的Redis客户端,它通过将Redis数据结构映射为Java对象和分布式对象,实现了在Java应用中方便地使用Redis,本文... 目录前言一、核心设计理念二、核心架构与通信层1. 基于 Netty 的异步非阻塞通信2. 编解码器三、

Java HashMap的底层实现原理深度解析

《JavaHashMap的底层实现原理深度解析》HashMap基于数组+链表+红黑树结构,通过哈希算法和扩容机制优化性能,负载因子与树化阈值平衡效率,是Java开发必备的高效数据结构,本文给大家介绍... 目录一、概述:HashMap的宏观结构二、核心数据结构解析1. 数组(桶数组)2. 链表节点(Node

Java 虚拟线程的创建与使用深度解析

《Java虚拟线程的创建与使用深度解析》虚拟线程是Java19中以预览特性形式引入,Java21起正式发布的轻量级线程,本文给大家介绍Java虚拟线程的创建与使用,感兴趣的朋友一起看看吧... 目录一、虚拟线程简介1.1 什么是虚拟线程?1.2 为什么需要虚拟线程?二、虚拟线程与平台线程对比代码对比示例:三

Python函数作用域与闭包举例深度解析

《Python函数作用域与闭包举例深度解析》Python函数的作用域规则和闭包是编程中的关键概念,它们决定了变量的访问和生命周期,:本文主要介绍Python函数作用域与闭包的相关资料,文中通过代码... 目录1. 基础作用域访问示例1:访问全局变量示例2:访问外层函数变量2. 闭包基础示例3:简单闭包示例4

Java 缓存框架 Caffeine 应用场景解析

《Java缓存框架Caffeine应用场景解析》文章介绍Caffeine作为高性能Java本地缓存框架,基于W-TinyLFU算法,支持异步加载、灵活过期策略、内存安全机制及统计监控,重点解析其... 目录一、Caffeine 简介1. 框架概述1.1 Caffeine的核心优势二、Caffeine 基础2

深度解析Python中递归下降解析器的原理与实现

《深度解析Python中递归下降解析器的原理与实现》在编译器设计、配置文件处理和数据转换领域,递归下降解析器是最常用且最直观的解析技术,本文将详细介绍递归下降解析器的原理与实现,感兴趣的小伙伴可以跟随... 目录引言:解析器的核心价值一、递归下降解析器基础1.1 核心概念解析1.2 基本架构二、简单算术表达

深度解析Java @Serial 注解及常见错误案例

《深度解析Java@Serial注解及常见错误案例》Java14引入@Serial注解,用于编译时校验序列化成员,替代传统方式解决运行时错误,适用于Serializable类的方法/字段,需注意签... 目录Java @Serial 注解深度解析1. 注解本质2. 核心作用(1) 主要用途(2) 适用位置3

Java MCP 的鉴权深度解析

《JavaMCP的鉴权深度解析》文章介绍JavaMCP鉴权的实现方式,指出客户端可通过queryString、header或env传递鉴权信息,服务器端支持工具单独鉴权、过滤器集中鉴权及启动时鉴权... 目录一、MCP Client 侧(负责传递,比较简单)(1)常见的 mcpServers json 配置

GSON框架下将百度天气JSON数据转JavaBean

《GSON框架下将百度天气JSON数据转JavaBean》这篇文章主要为大家详细介绍了如何在GSON框架下实现将百度天气JSON数据转JavaBean,文中的示例代码讲解详细,感兴趣的小伙伴可以了解下... 目录前言一、百度天气jsON1、请求参数2、返回参数3、属性映射二、GSON属性映射实战1、类对象映

从基础到高级详解Python数值格式化输出的完全指南

《从基础到高级详解Python数值格式化输出的完全指南》在数据分析、金融计算和科学报告领域,数值格式化是提升可读性和专业性的关键技术,本文将深入解析Python中数值格式化输出的相关方法,感兴趣的小伙... 目录引言:数值格式化的核心价值一、基础格式化方法1.1 三种核心格式化方式对比1.2 基础格式化示例