Pytorch导出FP16 ONNX模型

2024-04-10 09:04
文章标签 模型 导出 pytorch onnx fp16

本文主要是介绍Pytorch导出FP16 ONNX模型,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

一般Pytorch导出ONNX时默认都是用的FP32,但有时需要导出FP16的ONNX模型,这样在部署时能够方便的将计算以及IO改成FP16,并且ONNX文件体积也会更小。想导出FP16的ONNX模型也比较简单,一般情况下只需要在导出FP32 ONNX的基础上调用下model.half()将模型相关权重转为FP16,然后输入的Tensor也改成FP16即可,具体操作可参考如下示例代码。这里需要注意下,当前Pytorch要导出FP16的ONNX必须将模型以及输入Tensor的device设置成GPU,否则会报很多算子不支持FP16计算的提示。

import torch
from torchvision.models import resnet50def main():export_fp16 = Trueexport_onnx_path = f"resnet50_fp{16 if export_fp16 else 32}.onnx"device = torch.device("cuda:0")model = resnet50()model.eval()model.to(device)if export_fp16:model.half()with torch.inference_mode():dtype = torch.float16 if export_fp16 else torch.float32x = torch.randn(size=(1, 3, 224, 224), dtype=dtype, device=device)torch.onnx.export(model=model,args=(x,),f=export_onnx_path,input_names=["image"],output_names=["output"],dynamic_axes={"image": {2: "width", 3: "height"}},opset_version=17)if __name__ == '__main__':main()

通过Netron可视化工具可以看到导出的FP16 ONNX的输入/输出的tensor类型都是float16
在这里插入图片描述

并且通过对比可以看到,FP16的ONNX模型比FP32的文件更小(48.6MB vs 97.3MB)。
在这里插入图片描述
大多数情况可以按照上述操作进行正常转换,但也有一些比较头大的场景,因为你永远无法知道拿到的模型会有多奇葩,例如下面示例:
错误导出FP16 ONNX示例

import torch
import torch.nn as nn
import torch.nn.functional as Fclass MyModel(nn.Module):def __init__(self) -> None:super().__init__()self.conv = nn.Conv2d(3, 1, kernel_size=3, stride=2, padding=1)def forward(self, x):x = self.conv(x)kernel = torch.tensor([[0.1, 0.1, 0.1],[0.1, 0.1, 0.1],[0.1, 0.1, 0.1]], dtype=torch.float32, device=x.device).reshape([1, 1, 3, 3])x = F.conv2d(x, weight=kernel, bias=None, stride=1)return xdef main():export_fp16 = Trueexport_onnx_path = f"my_model_fp{16 if export_fp16 else 32}.onnx"device = torch.device("cuda:0")model = MyModel()model.eval()model.to(device)if export_fp16:model.half()with torch.inference_mode():dtype = torch.float16 if export_fp16 else torch.float32x = torch.randn(size=(1, 3, 224, 224), dtype=dtype, device=device)model(x)torch.onnx.export(model=model,args=(x,),f=export_onnx_path,input_names=["image"],output_names=["output"],dynamic_axes={"image": {2: "width", 3: "height"}},opset_version=17)if __name__ == '__main__':main()

执行以上代码后会报如下错误信息:

/src/ATen/native/cudnn/Conv_v8.cpp:80.)return F.conv2d(input, weight, bias, self.stride,
Traceback (most recent call last):File "/home/wz/my_projects/py_projects/export_fp16/example.py", line 47, in <module>main()File "/home/wz/my_projects/py_projects/export_fp16/example.py", line 36, in mainmodel(x)File "/home/wz/miniconda3/envs/torch2.0.1/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_implreturn forward_call(*args, **kwargs)File "/home/wz/my_projects/py_projects/export_fp16/example.py", line 17, in forwardx = F.conv2d(x, weight=kernel, bias=None, stride=1)RuntimeError: Input type (torch.cuda.HalfTensor) and weight type (torch.cuda.FloatTensor) should be the same

简单来说就是在推理过程中遇到两种不同类型的数据要计算,torch.cuda.HalfTensor(FP16) 和torch.cuda.FloatTensor(FP32)。遇到这种情况一般常见有两种解法:

  • 一种是找到数据类型与我们预期不一致的地方,然后改成我们要想的dtype,例如上面示例是将kernel的dtype写死成了torch.float32,我们可以改成torch.float16或者写成x.dtype(这种会比较通用,会根据输入的Tensor类型自动切换)。这种方法有个弊端,如果代码里写死dtype的位置很多,改起来会比较头大。
  • 另一种是使用torch.autocast上下文管理器,该上下文管理器能够实现推理过程中自动进行混合精度计算,例如遇到能进行float16/bfloat16计算的场景会自动切换。具体使用方法可以查看官方文档。下面示例代码就是用torch.autocast上下文管理器来做自动转换。
import torch
import torch.nn as nn
import torch.nn.functional as Fclass MyModel(nn.Module):def __init__(self) -> None:super().__init__()self.conv = nn.Conv2d(3, 1, kernel_size=3, stride=2, padding=1)def forward(self, x):x = self.conv(x)kernel = torch.tensor([[0.1, 0.1, 0.1],[0.1, 0.1, 0.1],[0.1, 0.1, 0.1]], dtype=torch.float32, device=x.device).reshape([1, 1, 3, 3])x = F.conv2d(x, weight=kernel, bias=None, stride=1)return xdef main():export_fp16 = Trueexport_onnx_path = f"my_model_fp{16 if export_fp16 else 32}.onnx"device = torch.device("cuda:0")model = MyModel()model.eval()model.to(device)if export_fp16:model.half()with torch.autocast(device_type="cuda", dtype=torch.float16):with torch.inference_mode():dtype = torch.float16 if export_fp16 else torch.float32x = torch.randn(size=(1, 3, 224, 224), dtype=dtype, device=device)model(x)torch.onnx.export(model=model,args=(x,),f=export_onnx_path,input_names=["image"],output_names=["output"],dynamic_axes={"image": {2: "width", 3: "height"}},opset_version=17)if __name__ == '__main__':main()

使用上述代码能够正常导出ONNX模型,并且使用Netron可视化后可以看到导出的FP16 ONNX模型是符合预期的。
在这里插入图片描述

这篇关于Pytorch导出FP16 ONNX模型的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/890626

相关文章

使用EasyPoi快速导出Word文档功能的实现步骤

《使用EasyPoi快速导出Word文档功能的实现步骤》EasyPoi是一个基于ApachePOI的开源Java工具库,旨在简化Excel和Word文档的操作,本文将详细介绍如何使用EasyPoi快速... 目录一、准备工作1、引入依赖二、准备好一个word模版文件三、编写导出方法的工具类四、在Export

前端导出Excel文件出现乱码或文件损坏问题的解决办法

《前端导出Excel文件出现乱码或文件损坏问题的解决办法》在现代网页应用程序中,前端有时需要与后端进行数据交互,包括下载文件,:本文主要介绍前端导出Excel文件出现乱码或文件损坏问题的解决办法,... 目录1. 检查后端返回的数据格式2. 前端正确处理二进制数据方案 1:直接下载(推荐)方案 2:手动构造

Linux五种IO模型的使用解读

《Linux五种IO模型的使用解读》文章系统解析了Linux的五种IO模型(阻塞、非阻塞、IO复用、信号驱动、异步),重点区分同步与异步IO的本质差异,强调同步由用户发起,异步由内核触发,通过对比各模... 目录1.IO模型简介2.五种IO模型2.1 IO模型分析方法2.2 阻塞IO2.3 非阻塞IO2.4

JAVA实现亿级千万级数据顺序导出的示例代码

《JAVA实现亿级千万级数据顺序导出的示例代码》本文主要介绍了JAVA实现亿级千万级数据顺序导出的示例代码,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面... 前提:主要考虑控制内存占用空间,避免出现同时导出,导致主程序OOM问题。实现思路:A.启用线程池

springboot集成easypoi导出word换行处理过程

《springboot集成easypoi导出word换行处理过程》SpringBoot集成Easypoi导出Word时,换行符n失效显示为空格,解决方法包括生成段落或替换模板中n为回车,同时需确... 目录项目场景问题描述解决方案第一种:生成段落的方式第二种:替换模板的情况,换行符替换成回车总结项目场景s

oracle 11g导入\导出(expdp impdp)之导入过程

《oracle11g导入导出(expdpimpdp)之导入过程》导出需使用SEC.DMP格式,无分号;建立expdir目录(E:/exp)并确保存在;导入在cmd下执行,需sys用户权限;若需修... 目录准备文件导入(impdp)1、建立directory2、导入语句 3、更改密码总结上一个环节,我们讲了

Android 缓存日志Logcat导出与分析最佳实践

《Android缓存日志Logcat导出与分析最佳实践》本文全面介绍AndroidLogcat缓存日志的导出与分析方法,涵盖按进程、缓冲区类型及日志级别过滤,自动化工具使用,常见问题解决方案和最佳实... 目录android 缓存日志(Logcat)导出与分析全攻略为什么要导出缓存日志?按需过滤导出1. 按

Qt中实现多线程导出数据功能的四种方式小结

《Qt中实现多线程导出数据功能的四种方式小结》在以往的项目开发中,在很多地方用到了多线程,本文将记录下在Qt开发中用到的多线程技术实现方法,以导出指定范围的数字到txt文件为例,展示多线程不同的实现方... 目录前言导出文件的示例工具类QThreadQObject的moveToThread方法实现多线程QC

SpringBoot集成EasyExcel实现百万级别的数据导入导出实践指南

《SpringBoot集成EasyExcel实现百万级别的数据导入导出实践指南》本文将基于开源项目springboot-easyexcel-batch进行解析与扩展,手把手教大家如何在SpringBo... 目录项目结构概览核心依赖百万级导出实战场景核心代码效果百万级导入实战场景监听器和Service(核心

使用Python开发一个Ditto剪贴板数据导出工具

《使用Python开发一个Ditto剪贴板数据导出工具》在日常工作中,我们经常需要处理大量的剪贴板数据,下面将介绍如何使用Python的wxPython库开发一个图形化工具,实现从Ditto数据库中读... 目录前言运行结果项目需求分析技术选型核心功能实现1. Ditto数据库结构分析2. 数据库自动定位3