查看神经网络中间层特征矩阵及卷积核参数

2024-01-19 12:20

本文主要是介绍查看神经网络中间层特征矩阵及卷积核参数,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

可视化feature maps以及kernel weights,使用alexnet模型进行演示。

1. 查看中间层特征矩阵

alexnet模型,修改了向前传播

import torch
from torch import nn
from torch.nn import functional as F# 对花图像数据进行分类
class AlexNet(nn.Module):def __init__(self,num_classes=1000,init_weights=False, *args, **kwargs) -> None:super().__init__(*args, **kwargs)self.conv1 = nn.Conv2d(3,48,11,4,2)self.pool1 = nn.MaxPool2d(3,2)self.conv2 = nn.Conv2d(48,128,5,padding=2)self.pool2 = nn.MaxPool2d(3,2)self.conv3 = nn.Conv2d(128,192,3,padding=1)self.conv4 = nn.Conv2d(192,192,3,padding=1)self.conv5 = nn.Conv2d(192,128,3,padding=1)self.pool3 = nn.MaxPool2d(3,2)self.fc1 = nn.Linear(128*6*6,2048)self.fc2 = nn.Linear(2048,2048)self.fc3 = nn.Linear(2048,num_classes)# 是否进行初始化# 其实我们并不需要对其进行初始化,因为在pytorch中,对我们对卷积及全连接层,自动使用了凯明初始化方法进行了初始化if init_weights:self._initialize_weights()def forward(self,x):outputs = []  # 定义一个列表,返回我们要查看的哪一层的输出特征矩阵x = self.conv1(x)outputs.append(x)x = self.pool1(F.relu(x,inplace=True))x = self.conv2(x)outputs.append(x)x = self.pool2(F.relu(x,inplace=True))x = self.conv3(x)outputs.append(x)x = F.relu(x,inplace=True)x = F.relu(self.conv4(x),inplace=True)x = self.pool3(F.relu(self.conv5(x),inplace=True))x = x.view(-1,128*6*6)x = F.dropout(x,p=0.5)x = F.relu(self.fc1(x),inplace=True)x = F.dropout(x,p=0.5)x = F.relu(self.fc2(x),inplace=True)x = self.fc3(x)# for name,module in self.named_children():#     x = module(x)#     if name == ["conv1","conv2","conv3"]:#         outputs.append(x)return outputs# 初始化权重def _initialize_weights(self):for m in self.modules():if isinstance(m,nn.Conv2d):# 凯明初始化 - 何凯明nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')if m.bias is not None:nn.init.constant_(m.bias, 0)elif isinstance(m,nn.Linear):nn.init.normal_(m.weight, 0,0.01)  # 使用正态分布给权重赋值进行初始化nn.init.constant_(m.bias,0)

拿到向前传播的结果,对特征图进行可视化,这里,我们使用训练好的模型,直接加载模型参数。

注意,要使用与训练时相同的数据预处理。

import matplotlib.pyplot as plt
from torchvision import transforms
import alexnet_model
import torch
from PIL import Image
import numpy as np
from alexnet_model import AlexNet# AlexNet 数据预处理
transform = transforms.Compose([transforms.Resize((224, 224)),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
# 实例化模型
model = AlexNet(num_classes=5)
weights = torch.load("./alexnet_weight_20.pth", map_location="cpu")
model.load_state_dict(weights)image = Image.open("./images/yjx.jpg")
image = transform(image)
image = image.unsqueeze(0)with torch.no_grad():output = model(image)for feature_map in output:# (N,C,W,H) -> (C,W,H)im = np.squeeze(feature_map.detach().numpy())# (C,W,H) -> (W,H,C)im = np.transpose(im,[1,2,0])plt.figure()# 展示当前层的前12个通道for i in range(12):ax = plt.subplot(3,4,i+1) # i+1: 每个图的索引plt.imshow(im[:,:,i],cmap='gray')plt.show()

结果:

在这里插入图片描述


2. 查看卷积核参数

import matplotlib.pyplot as plt
import numpy as np
import torchfrom AlexNet.model import AlexNet# 实例化模型
model = AlexNet(num_classes=5)
weights = torch.load("./alexnet_weight_20.pth", map_location="cpu")
model.load_state_dict(weights)weights_keys = model.state_dict().keys()
for key in weights_keys:if "num_batches_tracked" in key:continueweight_t = model.state_dict()[key].numpy()weight_mean = weight_t.mean()weight_std = weight_t.std(ddof=1)weight_min = weight_t.min()weight_max = weight_t.max()print("mean is {}, std is {}, min is {}, max is {}".format(weight_mean, weight_std, weight_min, weight_max))weight_vec = np.reshape(weight_t,[-1])plt.hist(weight_vec,bins=50)plt.title(key)plt.show()

结果:

在这里插入图片描述

这篇关于查看神经网络中间层特征矩阵及卷积核参数的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/622416

相关文章

C#中通过Response.Headers设置自定义参数的代码示例

《C#中通过Response.Headers设置自定义参数的代码示例》:本文主要介绍C#中通过Response.Headers设置自定义响应头的方法,涵盖基础添加、安全校验、生产实践及调试技巧,强... 目录一、基础设置方法1. 直接添加自定义头2. 批量设置模式二、高级配置技巧1. 安全校验机制2. 类型

Linux中查看操作系统及其版本信息的多种方法

《Linux中查看操作系统及其版本信息的多种方法》在服务器运维或者部署系统中,经常需要确认服务器的系统版本、cpu信息等,在Linux系统中,有多种方法可以查看操作系统及其版本信息,以下是一些常用的方... 目录1. lsb_pythonrelease 命令2. /etc/os-release 文件3. h

在Android中使用WebView在线查看PDF文件的方法示例

《在Android中使用WebView在线查看PDF文件的方法示例》在Android应用开发中,有时我们需要在客户端展示PDF文件,以便用户可以阅读或交互,:本文主要介绍在Android中使用We... 目录简介:1. WebView组件介绍2. 在androidManifest.XML中添加Interne

SpringBoot 获取请求参数的常用注解及用法

《SpringBoot获取请求参数的常用注解及用法》SpringBoot通过@RequestParam、@PathVariable等注解支持从HTTP请求中获取参数,涵盖查询、路径、请求体、头、C... 目录SpringBoot 提供了多种注解来方便地从 HTTP 请求中获取参数以下是主要的注解及其用法:1

HTTP 与 SpringBoot 参数提交与接收协议方式

《HTTP与SpringBoot参数提交与接收协议方式》HTTP参数提交方式包括URL查询、表单、JSON/XML、路径变量、头部、Cookie、GraphQL、WebSocket和SSE,依据... 目录HTTP 协议支持多种参数提交方式,主要取决于请求方法(Method)和内容类型(Content-Ty

Linux实现查看某一端口是否开放

《Linux实现查看某一端口是否开放》文章介绍了三种检查端口6379是否开放的方法:通过lsof查看进程占用,用netstat区分TCP/UDP监听状态,以及用telnet测试远程连接可达性... 目录1、使用lsof 命令来查看端口是否开放2、使用netstat 命令来查看端口是否开放3、使用telnet

python中的显式声明类型参数使用方式

《python中的显式声明类型参数使用方式》文章探讨了Python3.10+版本中类型注解的使用,指出FastAPI官方示例强调显式声明参数类型,通过|操作符替代Union/Optional,可提升代... 目录背景python函数显式声明的类型汇总基本类型集合类型Optional and Union(py

Go语言使用Gin处理路由参数和查询参数

《Go语言使用Gin处理路由参数和查询参数》在WebAPI开发中,处理路由参数(PathParameter)和查询参数(QueryParameter)是非常常见的需求,下面我们就来看看Go语言... 目录一、路由参数 vs 查询参数二、Gin 获取路由参数和查询参数三、示例代码四、运行与测试1. 测试编程路

MySQL的触发器全解析(创建、查看触发器)

《MySQL的触发器全解析(创建、查看触发器)》MySQL触发器是与表关联的存储程序,当INSERT/UPDATE/DELETE事件发生时自动执行,用于维护数据一致性、日志记录和校验,优点包括自动执行... 目录触发器的概念:创建触www.chinasem.cn发器:查看触发器:查看当前数据库的所有触发器的定

Python lambda函数(匿名函数)、参数类型与递归全解析

《Pythonlambda函数(匿名函数)、参数类型与递归全解析》本文详解Python中lambda匿名函数、灵活参数类型和递归函数三大进阶特性,分别介绍其定义、应用场景及注意事项,助力编写简洁高效... 目录一、lambda 匿名函数:简洁的单行函数1. lambda 的定义与基本用法2. lambda