【ncnn android】算法移植(七)——pytorch2onnx代码粗看

2024-06-13 09:08

本文主要是介绍【ncnn android】算法移植(七)——pytorch2onnx代码粗看,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

目的:

  • 了解torch2onnx的流程
  • 了解其中的一些技术细节

1. 程序细节

  1. get_graph
    将pytorch的模型转成onnx需要的graph
  • graph, torch_out = _trace_and_get_graph_from_model(model, args, training)

  • trace, torch_out, inputs_states = torch.jit.get_trace_graph(model, args, _force_outplace=True, _return_inputs_states=True) warn_on_static_input_change(inputs_states)

  1. graph_export_onnx
proto, export_map = graph._export_onnx(params_dict, opset_version, dynamic_axes, defer_weight_export,operator_export_type, strip_doc_string, val_keep_init_as_ip)

2. 其他

  1. batchnorm
    在保存成onnx的时候,设置verbose=True,可以看有哪些属性。
%554 : Float(1, 16, 8, 8) = onnx::Conv[dilations=[1, 1], group=1, kernel_shape=[3, 3], pads=[1, 1, 1, 1], strides=[1, 1]](%550, %model.detect.context.inconv.conv.weight), scope: OnnxModel/DBFace[model]/DetectModule[detect]/ContextModule[context]/CBAModule[inconv]/Conv2d[conv] # /home/yangna/yangna/tool/anaconda2/envs/torch130/lib/python3.6/site-packages/torch/nn/modules/conv.py:342:0%555 : Float(1, 16, 8, 8) = onnx::BatchNormalization[epsilon=1e-05, momentum=0.9](%554, %model.detect.context.inconv.bn.weight, %model.detect.context.inconv.bn.bias, %model.detect.context.inconv.bn.running_mean, %model.detect.context.inconv.bn.running_var), scope: OnnxModel/DBFace[model]/DetectModule[detect]/ContextModule[context]/CBAModule[inconv]/BatchNorm2d[bn] # /home/yangna/yangna/tool/anaconda2/envs/torch130/lib/python3.6/site-packages/torch/nn/functional.py:1670:0%556 : Float(1, 16, 8, 8) = onnx::Relu(%555), scope: OnnxModel/DBFace[model]/DetectModule[detect]/ContextModule[context]/CBAModule[inconv]/ReLU[act] # /home/yangna/yangna/tool/anaconda2/envs/torch130/lib/python3.6/site-packages/torch/nn/functional.py:912:0%557 : Float(1, 16, 8, 8) = onnx::Conv[dilations=[1, 1], group=1, kernel_shape=[3, 3], pads=[1, 1, 1, 1], strides=[1, 1]](%556, %model.detect.context.upconv.conv.weight), scope: OnnxModel/DBFace[model]/DetectModule[detect]/ContextModule[context]/CBAModule[upconv]/Conv2d[conv] # /home/yangna/yangna/tool/anaconda2/envs/torch130/lib/python3.6/site-packages/torch/nn/modules/conv.py:342:0%558 : Float(1, 16, 8, 8) = onnx::BatchNormalization[epsilon=1e-05, momentum=0.9](%557, %model.detect.context.upconv.bn.weight, %model.detect.context.upconv.bn.bias, %model.detect.context.upconv.bn.running_mean, %model.detect.context.upconv.bn.running_var), scope: OnnxModel/DBFace[model]/DetectModule[detect]/ContextModule[context]/CBAModule[upconv]/BatchNorm2d[bn] # /home/yangna/yangna/tool/anaconda2/envs/torch130/lib/python3.6/site-packages/torch/nn/functional.py:1670:0%559 : Float(1, 16, 8, 8) = onnx::Relu(%558), scope: OnnxModel/DBFace[model]/DetectModule[detect]/ContextModule[context]/CBAModule[upconv]/ReLU[act] # /home/yangna/yangna/tool/anaconda2/envs/torch130/lib/python3.6/site-packages/torch/nn/functional.py:912:0

这里以batchnorm为例,说明一下:

  • 首先是pytorch中的:
    %558 : Float(1, 16, 8, 8) = onnx::BatchNormalization[epsilon=1e-05, momentum=0.9](%557, %model.detect.context.upconv.bn.weight, %model.detect.context.upconv.bn.bias, %model.detect.context.upconv.bn.running_mean, %model.detect.context.upconv.bn.running_var), scope: OnnxModel/DBFace[model]/DetectModule[detect]/ContextModule[context]/CBAModule[upconv]/BatchNorm2d[bn] # /home/yangna/yangna/tool/anaconda2/envs/torch130/lib/python3.6/site-packages/torch/nn/functional.py:1670:0
    其中小括号中就是要保存的参数的属性有:bn.weight bn.bias bn.running_mean bn.running_var

  • ncnn中onnx2ncnn中如何读取预训练权重。

const onnx::TensorProto& scale = weights[node.input(1)];
const onnx::TensorProto& B = weights[node.input(2)];
const onnx::TensorProto& mean = weights[node.input(3)];
const onnx::TensorProto& var = weights[node.input(4)];
  • node.input(1):bn.weight
  • node.input(2):bn.bias
  • node.input(3):bn.running_mean
  • node.input(4):bn.running_var
    顺序和pytorch2onnx写入的顺序一致
  1. maxpool
  • pytorch的打印信息
%pool_hm : Float(1, 1, 8, 8) = onnx::MaxPool[ceil_mode=0, kernel_shape=[3, 3], pads=[1, 1, 1, 1], strides=[1, 1]](%hm), scope: OnnxModel # /home/yangna/yangna/tool/anaconda2/envs/torch130/lib/python3.6/site-packages/torch/nn/functional.py:488:0
  • ncnn中如何读取结构参数
    因为maxpool层是没有预训练权重的,只有一些结构参数
std::string auto_pad = get_node_attr_s(node, "auto_pad");//TODO
std::vector<int> kernel_shape = get_node_attr_ai(node, "kernel_shape");
std::vector<int> strides = get_node_attr_ai(node, "strides");
std::vector<int> pads = get_node_attr_ai(node, "pads");
  • 注意:这里“auto_pad”字段和pytorch中的“ceil_model”字段是不一样的。这是因为pytorch2onnx版本和ncnn版本不对应造成的。可能ncnn20180704版时,maxpool的onnx表达中有“auto_pad”属性。

这篇关于【ncnn android】算法移植(七)——pytorch2onnx代码粗看的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1056868

相关文章

Android使用ImageView.ScaleType实现图片的缩放与裁剪功能

《Android使用ImageView.ScaleType实现图片的缩放与裁剪功能》ImageView是最常用的控件之一,它用于展示各种类型的图片,为了能够根据需求调整图片的显示效果,Android提... 目录什么是 ImageView.ScaleType?FIT_XYFIT_STARTFIT_CENTE

利用Python调试串口的示例代码

《利用Python调试串口的示例代码》在嵌入式开发、物联网设备调试过程中,串口通信是最基础的调试手段本文将带你用Python+ttkbootstrap打造一款高颜值、多功能的串口调试助手,需要的可以了... 目录概述:为什么需要专业的串口调试工具项目架构设计1.1 技术栈选型1.2 关键类说明1.3 线程模

Python Transformers库(NLP处理库)案例代码讲解

《PythonTransformers库(NLP处理库)案例代码讲解》本文介绍transformers库的全面讲解,包含基础知识、高级用法、案例代码及学习路径,内容经过组织,适合不同阶段的学习者,对... 目录一、基础知识1. Transformers 库简介2. 安装与环境配置3. 快速上手示例二、核心模

Java的栈与队列实现代码解析

《Java的栈与队列实现代码解析》栈是常见的线性数据结构,栈的特点是以先进后出的形式,后进先出,先进后出,分为栈底和栈顶,栈应用于内存的分配,表达式求值,存储临时的数据和方法的调用等,本文给大家介绍J... 目录栈的概念(Stack)栈的实现代码队列(Queue)模拟实现队列(双链表实现)循环队列(循环数组

Android实现在线预览office文档的示例详解

《Android实现在线预览office文档的示例详解》在移动端展示在线Office文档(如Word、Excel、PPT)是一项常见需求,这篇文章为大家重点介绍了两种方案的实现方法,希望对大家有一定的... 目录一、项目概述二、相关技术知识三、实现思路3.1 方案一:WebView + Office Onl

Android实现两台手机屏幕共享和远程控制功能

《Android实现两台手机屏幕共享和远程控制功能》在远程协助、在线教学、技术支持等多种场景下,实时获得另一部移动设备的屏幕画面,并对其进行操作,具有极高的应用价值,本项目旨在实现两台Android手... 目录一、项目概述二、相关知识2.1 MediaProjection API2.2 Socket 网络

Android实现悬浮按钮功能

《Android实现悬浮按钮功能》在很多场景中,我们希望在应用或系统任意界面上都能看到一个小的“悬浮按钮”(FloatingButton),用来快速启动工具、展示未读信息或快捷操作,所以本文给大家介绍... 目录一、项目概述二、相关技术知识三、实现思路四、整合代码4.1 Java 代码(MainActivi

Android Mainline基础简介

《AndroidMainline基础简介》AndroidMainline是通过模块化更新Android核心组件的框架,可能提高安全性,本文给大家介绍AndroidMainline基础简介,感兴趣的朋... 目录关键要点什么是 android Mainline?Android Mainline 的工作原理关键

使用Java将DOCX文档解析为Markdown文档的代码实现

《使用Java将DOCX文档解析为Markdown文档的代码实现》在现代文档处理中,Markdown(MD)因其简洁的语法和良好的可读性,逐渐成为开发者、技术写作者和内容创作者的首选格式,然而,许多文... 目录引言1. 工具和库介绍2. 安装依赖库3. 使用Apache POI解析DOCX文档4. 将解析

C++使用printf语句实现进制转换的示例代码

《C++使用printf语句实现进制转换的示例代码》在C语言中,printf函数可以直接实现部分进制转换功能,通过格式说明符(formatspecifier)快速输出不同进制的数值,下面给大家分享C+... 目录一、printf 原生支持的进制转换1. 十进制、八进制、十六进制转换2. 显示进制前缀3. 指