深度学习基础知识 使用torchsummary、netron、tensorboardX查看模参数结构

本文主要是介绍深度学习基础知识 使用torchsummary、netron、tensorboardX查看模参数结构,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

深度学习基础知识 使用torchsummary、netron、tensorboardX查看模参数结构

  • 1、直接打印网络参数结构
  • 2、采用torchsummary检测、查看模型参数结构
  • 3、采用netron检测、查看模型参数结构
  • 3、使用tensorboardX

1、直接打印网络参数结构

import torch.nn as nn
from torchsummary import summary
import torchclass Alexnet(nn.Module):def __init__(self):super().__init__()self.net = nn.Sequential(nn.Conv2d(3, 96, kernel_size=11, stride=4, padding=1), nn.ReLU(),nn.MaxPool2d(kernel_size=3, stride=2),nn.Conv2d(96, 256, kernel_size=5, padding=2), nn.ReLU(),nn.MaxPool2d(kernel_size=3, stride=2),nn.Conv2d(256, 384, kernel_size=3, padding=1), nn.ReLU(),nn.Conv2d(384, 384, kernel_size=3, padding=1), nn.ReLU(),nn.Conv2d(384, 256, kernel_size=3, padding=1), nn.ReLU(),nn.MaxPool2d(kernel_size=3, stride=2),nn.Flatten(), nn.Linear(256 * 5 * 5, 4096), nn.ReLU(),nn.Dropout(0.5),nn.Linear(4096, 4096), nn.ReLU(),nn.Dropout(0.5),nn.Linear(4096, 10))def forward(self, X):return self.net(X)if __name__=="__main__":device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu")model=Alexnet().to(device)print(model)# summary(model,(3,224,224),16)

结果输出:

Alexnet((net): Sequential((0): Conv2d(3, 96, kernel_size=(11, 11), stride=(4, 4), padding=(1, 1))(1): ReLU()(2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)(3): Conv2d(96, 256, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))(4): ReLU()(5): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)(6): Conv2d(256, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(7): ReLU()(8): Conv2d(384, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(9): ReLU()(10): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(11): ReLU()(12): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)(13): Flatten(start_dim=1, end_dim=-1)(14): Linear(in_features=6400, out_features=4096, bias=True)(15): ReLU()(16): Dropout(p=0.5, inplace=False)(17): Linear(in_features=4096, out_features=4096, bias=True)(18): ReLU()(19): Dropout(p=0.5, inplace=False)(20): Linear(in_features=4096, out_features=10, bias=True))
)

上述方案存在的问题是:当网络参数设置存在错误时,无法检测出来

2、采用torchsummary检测、查看模型参数结构

安装torchsummary

pip install torchsummary

通常采用torchsummary打印网络结构参数时,会出现以下问题
代码:

import torch.nn as nn
from torchsummary import summaryclass Alexnet(nn.Module):def __init__(self):super().__init__()self.net = nn.Sequential(nn.Conv2d(3, 96, kernel_size=11, stride=4, padding=1), nn.ReLU(),nn.MaxPool2d(kernel_size=3, stride=2),nn.Conv2d(96, 256, kernel_size=5, padding=2), nn.ReLU(),nn.MaxPool2d(kernel_size=3, stride=2),nn.Conv2d(256, 384, kernel_size=3, padding=1), nn.ReLU(),nn.Conv2d(384, 384, kernel_size=3, padding=1), nn.ReLU(),nn.Conv2d(384, 256, kernel_size=3, padding=1), nn.ReLU(),nn.MaxPool2d(kernel_size=3, stride=2),nn.Flatten(), nn.Linear(256 * 5 * 5, 4096), nn.ReLU(),nn.Dropout(0.5),nn.Linear(4096, 4096), nn.ReLU(),nn.Dropout(0.5),nn.Linear(4096, 10))def forward(self, X):return self.net(X)net = Alexnet()
print(summary(net, (3, 224, 224), 8))

报错内容如下:

RuntimeError: Input type (torch.cuda.FloatTensor) and weight type (torch.FloatTensor) should be the same

报错原因分析:

在使用torchsummary可视化模型时候报错,报这个错误是因为类型不匹配,根据报错内容可以看出Input type为torch.FloatTensor(CPU数据类型),而weight type(即网络权重参数这些)为torch.cuda.FloatTensor(GPU数据类型)

解决方案:

将model传到GPU上便可。将代码如下修改便可正常运行

if __name__ == "__main__":from torchsummary import summarydevice = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")model = UNet().to(device)	# modifyprint(model)summary(model, input_size=(3, 224, 224))

整体代码:

import torch.nn as nn
from torchsummary import summary
import torchclass Alexnet(nn.Module):def __init__(self):super().__init__()self.net = nn.Sequential(nn.Conv2d(3, 96, kernel_size=11, stride=4, padding=1), nn.ReLU(),nn.MaxPool2d(kernel_size=3, stride=2),nn.Conv2d(96, 256, kernel_size=5, padding=2), nn.ReLU(),nn.MaxPool2d(kernel_size=3, stride=2),nn.Conv2d(256, 384, kernel_size=3, padding=1), nn.ReLU(),nn.Conv2d(384, 384, kernel_size=3, padding=1), nn.ReLU(),nn.Conv2d(384, 256, kernel_size=3, padding=1), nn.ReLU(),nn.MaxPool2d(kernel_size=3, stride=2),nn.Flatten(), nn.Linear(256 * 5 * 5, 4096), nn.ReLU(),nn.Dropout(0.5),nn.Linear(4096, 4096), nn.ReLU(),nn.Dropout(0.5),nn.Linear(4096, 10))def forward(self, X):return self.net(X)if __name__=="__main__":device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu")model=Alexnet().to(device)# print(model)summary(model,(3,224,224),16)  # 16:表示传入的数据批次

打印结果:

----------------------------------------------------------------Layer (type)               Output Shape         Param #
================================================================Conv2d-1           [16, 96, 54, 54]          34,944ReLU-2           [16, 96, 54, 54]               0MaxPool2d-3           [16, 96, 26, 26]               0Conv2d-4          [16, 256, 26, 26]         614,656ReLU-5          [16, 256, 26, 26]               0MaxPool2d-6          [16, 256, 12, 12]               0Conv2d-7          [16, 384, 12, 12]         885,120ReLU-8          [16, 384, 12, 12]               0Conv2d-9          [16, 384, 12, 12]       1,327,488ReLU-10          [16, 384, 12, 12]               0Conv2d-11          [16, 256, 12, 12]         884,992ReLU-12          [16, 256, 12, 12]               0MaxPool2d-13            [16, 256, 5, 5]               0Flatten-14                 [16, 6400]               0Linear-15                 [16, 4096]      26,218,496ReLU-16                 [16, 4096]               0Dropout-17                 [16, 4096]               0Linear-18                 [16, 4096]      16,781,312ReLU-19                 [16, 4096]               0Dropout-20                 [16, 4096]               0Linear-21                   [16, 10]          40,970
================================================================
Total params: 46,787,978
Trainable params: 46,787,978
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 9.19
Forward/backward pass size (MB): 163.58
Params size (MB): 178.48
Estimated Total Size (MB): 351.25
----------------------------------------------------------------

3、采用netron检测、查看模型参数结构

安装netron与onnx

pip install netron onnx

代码实现:

import torch.nn as nn
import netron
import torch
from onnx import shape_inference
import onnxclass Alexnet(nn.Module):def __init__(self):super().__init__()self.net = nn.Sequential(nn.Conv2d(3, 96, kernel_size=11, stride=4, padding=1), nn.ReLU(),nn.MaxPool2d(kernel_size=3, stride=2),nn.Conv2d(96, 256, kernel_size=5, padding=2), nn.ReLU(),nn.MaxPool2d(kernel_size=3, stride=2),nn.Conv2d(256, 384, kernel_size=3, padding=1), nn.ReLU(),nn.Conv2d(384, 384, kernel_size=3, padding=1), nn.ReLU(),nn.Conv2d(384, 256, kernel_size=3, padding=1), nn.ReLU(),nn.MaxPool2d(kernel_size=3, stride=2),nn.Flatten(), nn.Linear(256 * 5 * 5, 4096), nn.ReLU(),nn.Dropout(0.5),nn.Linear(4096, 4096), nn.ReLU(),nn.Dropout(0.5),nn.Linear(4096, 10))def forward(self, X):return self.net(X)if __name__=="__main__":device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu")model=Alexnet()temp_image=torch.rand((1,3,224,224))# 1、利用torch.onnx.export,先将模型导出为onnx格式的文件,保存到本地./model.onnxtorch.onnx.export(model=model,args=temp_image,f='model.onnx',input_names=['image'],output_names=['feature_map'])# 2、加载进onxx模型,并推理,然后再保存覆盖原先模型onnx.save(onnx.shape_inference.infer_shapes(onnx.load("model.onnx")),"model.onnx")netron.start('model.onnx')

运行后,显示结构:
在这里插入图片描述
在这里插入图片描述

3、使用tensorboardX

在这里插入图片描述
代码实现:

import torch
import torch.nn as nn
from tensorboardX import SummaryWriter as SummaryWriterclass Alexnet(nn.Module):def __init__(self):super().__init__()self.net = nn.Sequential(nn.Conv2d(3, 96, kernel_size=11, stride=4, padding=1), nn.ReLU(),nn.MaxPool2d(kernel_size=3, stride=2),nn.Conv2d(96, 256, kernel_size=5, padding=2), nn.ReLU(),nn.MaxPool2d(kernel_size=3, stride=2),nn.Conv2d(256, 384, kernel_size=3, padding=1), nn.ReLU(),nn.Conv2d(384, 384, kernel_size=3, padding=1), nn.ReLU(),nn.Conv2d(384, 256, kernel_size=3, padding=1), nn.ReLU(),nn.MaxPool2d(kernel_size=3, stride=2),nn.Flatten(), nn.Linear(256 * 5 * 5, 4096), nn.ReLU(),nn.Dropout(0.5),nn.Linear(4096, 4096), nn.ReLU(),nn.Dropout(0.5),nn.Linear(4096, 10))def forward(self, X):return self.net(X)net = Alexnet()
img = torch.rand((1, 3, 224, 224))
with SummaryWriter(log_dir='logs') as w:w.add_graph(net, img)

运行后,会在本地生成一个log日志文件
在命令行运行以下指令:

tensorboard --logdir ./logs --port 6006

这篇关于深度学习基础知识 使用torchsummary、netron、tensorboardX查看模参数结构的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!


原文地址:
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.chinasem.cn/article/177206

相关文章

使用Java将实体类转换为JSON并输出到控制台的完整过程

《使用Java将实体类转换为JSON并输出到控制台的完整过程》在软件开发的过程中,Java是一种广泛使用的编程语言,而在众多应用中,数据的传输和存储经常需要使用JSON格式,用Java将实体类转换为J... 在软件开发的过程中,Java是一种广泛使用的编程语言,而在众多应用中,数据的传输和存储经常需要使用j

一文详解PostgreSQL复制参数

《一文详解PostgreSQL复制参数》PostgreSQL作为一款功能强大的开源关系型数据库,其复制功能对于构建高可用性系统至关重要,本文给大家详细介绍了PostgreSQL的复制参数,需要的朋友可... 目录一、复制参数基础概念二、核心复制参数深度解析1. max_wal_seChina编程nders:WAL

一文详解如何查看本地MySQL的安装路径

《一文详解如何查看本地MySQL的安装路径》本地安装MySQL对于初学者或者开发人员来说是一项基础技能,但在安装过程中可能会遇到各种问题,:本文主要介绍如何查看本地MySQL安装路径的相关资料,需... 目录1. 如何查看本地mysql的安装路径1.1. 方法1:通过查询本地服务1.2. 方法2:通过MyS

Nginx使用Keepalived部署web集群(高可用高性能负载均衡)实战案例

《Nginx使用Keepalived部署web集群(高可用高性能负载均衡)实战案例》本文介绍Nginx+Keepalived实现Web集群高可用负载均衡的部署与测试,涵盖架构设计、环境配置、健康检查、... 目录前言一、架构设计二、环境准备三、案例部署配置 前端 Keepalived配置 前端 Nginx

Python logging模块使用示例详解

《Pythonlogging模块使用示例详解》Python的logging模块是一个灵活且强大的日志记录工具,广泛应用于应用程序的调试、运行监控和问题排查,下面给大家介绍Pythonlogging模... 目录一、为什么使用 logging 模块?二、核心组件三、日志级别四、基本使用步骤五、快速配置(bas

使用animation.css库快速实现CSS3旋转动画效果

《使用animation.css库快速实现CSS3旋转动画效果》随着Web技术的不断发展,动画效果已经成为了网页设计中不可或缺的一部分,本文将深入探讨animation.css的工作原理,如何使用以及... 目录1. css3动画技术简介2. animation.css库介绍2.1 animation.cs

使用雪花算法产生id导致前端精度缺失问题解决方案

《使用雪花算法产生id导致前端精度缺失问题解决方案》雪花算法由Twitter提出,设计目的是生成唯一的、递增的ID,下面:本文主要介绍使用雪花算法产生id导致前端精度缺失问题的解决方案,文中通过代... 目录一、问题根源二、解决方案1. 全局配置Jackson序列化规则2. 实体类必须使用Long封装类3.

Python文件操作与IO流的使用方式

《Python文件操作与IO流的使用方式》:本文主要介绍Python文件操作与IO流的使用方式,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录一、python文件操作基础1. 打开文件2. 关闭文件二、文件读写操作1.www.chinasem.cn 读取文件2. 写

PyQt6中QMainWindow组件的使用详解

《PyQt6中QMainWindow组件的使用详解》QMainWindow是PyQt6中用于构建桌面应用程序的基础组件,本文主要介绍了PyQt6中QMainWindow组件的使用,具有一定的参考价值,... 目录1. QMainWindow 组php件概述2. 使用 QMainWindow3. QMainW

使用Python自动化生成PPT并结合LLM生成内容的代码解析

《使用Python自动化生成PPT并结合LLM生成内容的代码解析》PowerPoint是常用的文档工具,但手动设计和排版耗时耗力,本文将展示如何通过Python自动化提取PPT样式并生成新PPT,同时... 目录核心代码解析1. 提取 PPT 样式到 jsON关键步骤:代码片段:2. 应用 JSON 样式到