Pytorch复习笔记--导出Onnx模型为动态输入和静态输入

2023-10-08 03:20

本文主要是介绍Pytorch复习笔记--导出Onnx模型为动态输入和静态输入,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

目录

1--动态输入和静态输入

2--Pytorch API

3--完整代码演示

4--模型可视化

5--测试动态导出的Onnx模型


1--动态输入和静态输入

        当使用 Pytorch 将网络导出为 Onnx 模型格式时,可以导出为动态输入和静态输入两种方式。动态输入即模型输入数据的部分维度是动态的,可以由用户在使用模型时自主设定;静态输入即模型输入数据的维度是静态的,不能够改变,当用户使用模型时只能输入指定维度的数据进行推理。

        显然,动态输入的通用性比静态输入更强。

2--Pytorch API

        在 Pytorch 中,通过 torch.onnx.export() 的 dynamic_axes 参数来指定动态输入和静态输入,dynamic_axes 的默认值为 None,即默认为静态输入。

        以下展示动态导出的用法,通过定义 dynamic_axes 参数来设置动态导出输入。dynamic_axes 中的 0、2、3 表示相应的维度设置为动态值;

# 导出为动态输入
input_name = 'input'
output_name = 'output'
torch.onnx.export(model, input_data, "Dynamics_InputNet.onnx",opset_version=11,input_names=[input_name],output_names=[output_name],dynamic_axes={input_name: {0: 'batch_size', 2: 'input_height', 3: 'input_width'},output_name: {0: 'batch_size', 2: 'output_height', 3: 'output_width'}})

3--完整代码演示

        在以下代码中,定义了一个网络,并使用动态导出和静态导出两种方式,将网络导出为 Onnx 模型格式。

import torch
import torch.nn as nnclass Model_Net(nn.Module):def __init__(self):super(Model_Net, self).__init__()self.layer1 = nn.Sequential(nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, stride=1, padding=1),nn.BatchNorm2d(64),nn.ReLU(inplace=True),nn.Conv2d(in_channels=64, out_channels=256, kernel_size=3, stride=1, padding=1),nn.BatchNorm2d(256),nn.ReLU(inplace=True),)def forward(self, data):data = self.layer1(data)return dataif __name__ == "__main__":# 设置输入参数Batch_size = 8Channel = 3Height = 256Width = 256input_data = torch.rand((Batch_size, Channel, Height, Width))# 实例化模型model = Model_Net()# 导出为静态输入input_name = 'input'output_name = 'output'torch.onnx.export(model, input_data, "Static_InputNet.onnx", verbose=True, input_names=[input_name], output_names=[output_name])# 导出为动态输入torch.onnx.export(model, input_data, "Dynamics_InputNet.onnx",opset_version=11,input_names=[input_name],output_names=[output_name],dynamic_axes={input_name: {0: 'batch_size', 2: 'input_height', 3: 'input_width'},output_name: {0: 'batch_size', 2: 'output_height', 3: 'output_width'}})

4--模型可视化

        通过 netron 库可视化导出的静态模型和动态模型,代码如下:

import netronnetron.start("./Dynamics_InputNet.onnx")

        静态模型可视化:

         动态模型可视化:

5--测试动态导出的Onnx模型

import numpy as np
import onnx
import onnxruntimeif __name__ == "__main__":input_data1 = np.random.rand(4, 3, 256, 256).astype(np.float32)input_data2 = np.random.rand(8, 3, 512, 512).astype(np.float32)# 导入 Onnx 模型Onnx_file = "./Dynamics_InputNet.onnx"Model = onnx.load(Onnx_file)onnx.checker.check_model(Model) # 验证Onnx模型是否准确# 使用 onnxruntime 推理model = onnxruntime.InferenceSession(Onnx_file, providers=['TensorrtExecutionProvider', 'CUDAExecutionProvider', 'CPUExecutionProvider'])input_name = model.get_inputs()[0].nameoutput_name = model.get_outputs()[0].nameoutput1 = model.run([output_name], {input_name:input_data1})output2 = model.run([output_name], {input_name:input_data2})print('output1.shape: ', np.squeeze(np.array(output1), 0).shape)print('output2.shape: ', np.squeeze(np.array(output2), 0).shape)

         由输出结果可知,对应动态输入 Onnx 模型,其输出维度也是动态的,并且为对应关系,则表明导出的 Onnx 模型无误。

这篇关于Pytorch复习笔记--导出Onnx模型为动态输入和静态输入的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/162362

相关文章

Python学习笔记之getattr和hasattr用法示例详解

《Python学习笔记之getattr和hasattr用法示例详解》在Python中,hasattr()、getattr()和setattr()是一组内置函数,用于对对象的属性进行操作和查询,这篇文章... 目录1.getattr用法详解1.1 基本作用1.2 示例1.3 原理2.hasattr用法详解2.

Android 缓存日志Logcat导出与分析最佳实践

《Android缓存日志Logcat导出与分析最佳实践》本文全面介绍AndroidLogcat缓存日志的导出与分析方法,涵盖按进程、缓冲区类型及日志级别过滤,自动化工具使用,常见问题解决方案和最佳实... 目录android 缓存日志(Logcat)导出与分析全攻略为什么要导出缓存日志?按需过滤导出1. 按

通过配置nginx访问服务器静态资源的过程

《通过配置nginx访问服务器静态资源的过程》文章介绍了图片存储路径设置、Nginx服务器配置及通过http://192.168.206.170:8007/a.png访问图片的方法,涵盖图片管理与服务... 目录1.图片存储路径2.nginx配置3.访问图片方式总结1.图片存储路径2.nginx配置

Qt中实现多线程导出数据功能的四种方式小结

《Qt中实现多线程导出数据功能的四种方式小结》在以往的项目开发中,在很多地方用到了多线程,本文将记录下在Qt开发中用到的多线程技术实现方法,以导出指定范围的数字到txt文件为例,展示多线程不同的实现方... 目录前言导出文件的示例工具类QThreadQObject的moveToThread方法实现多线程QC

SpringBoot集成EasyExcel实现百万级别的数据导入导出实践指南

《SpringBoot集成EasyExcel实现百万级别的数据导入导出实践指南》本文将基于开源项目springboot-easyexcel-batch进行解析与扩展,手把手教大家如何在SpringBo... 目录项目结构概览核心依赖百万级导出实战场景核心代码效果百万级导入实战场景监听器和Service(核心

使用Python开发一个Ditto剪贴板数据导出工具

《使用Python开发一个Ditto剪贴板数据导出工具》在日常工作中,我们经常需要处理大量的剪贴板数据,下面将介绍如何使用Python的wxPython库开发一个图形化工具,实现从Ditto数据库中读... 目录前言运行结果项目需求分析技术选型核心功能实现1. Ditto数据库结构分析2. 数据库自动定位3

shell脚本批量导出redis key-value方式

《shell脚本批量导出rediskey-value方式》为避免keys全量扫描导致Redis卡顿,可先通过dump.rdb备份文件在本地恢复,再使用scan命令渐进导出key-value,通过CN... 目录1 背景2 详细步骤2.1 本地docker启动Redis2.2 shell批量导出脚本3 附录总

go动态限制并发数量的实现示例

《go动态限制并发数量的实现示例》本文主要介绍了Go并发控制方法,通过带缓冲通道和第三方库实现并发数量限制,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面... 目录带有缓冲大小的通道使用第三方库其他控制并发的方法因为go从语言层面支持并发,所以面试百分百会问到

PyTorch中的词嵌入层(nn.Embedding)详解与实战应用示例

《PyTorch中的词嵌入层(nn.Embedding)详解与实战应用示例》词嵌入解决NLP维度灾难,捕捉语义关系,PyTorch的nn.Embedding模块提供灵活实现,支持参数配置、预训练及变长... 目录一、词嵌入(Word Embedding)简介为什么需要词嵌入?二、PyTorch中的nn.Em

SpringBoot集成EasyPoi实现Excel模板导出成PDF文件

《SpringBoot集成EasyPoi实现Excel模板导出成PDF文件》在日常工作中,我们经常需要将数据导出成Excel表格或PDF文件,本文将介绍如何在SpringBoot项目中集成EasyPo... 目录前言摘要简介源代码解析应用场景案例优缺点分析类代码方法介绍测试用例小结前言在日常工作中,我们经