pytorch转onnx转mnn并验证

2023-12-18 02:58
文章标签 验证 pytorch onnx mnn

本文主要是介绍pytorch转onnx转mnn并验证,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

pytorch训练的模型在实际使用时往往需要转换成onnx或mnn部署,训练好的模型需先转成onnx:

import sys
import argparse
import torch
import torchvision
import torch.onnx

from mobilenetv2 import MobileNetV2

if __name__ == '__main__':
    # Build the network (2 output classes) and load the trained weights.
    # map_location='cpu' lets the checkpoint load on machines without a GPU.
    model = MobileNetV2(2)
    model_path = './model/mobilenetv2.mdl'
    model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
    # Switch to inference mode AFTER loading the weights so BatchNorm/Dropout
    # behave deterministically in the exported graph (original called eval()
    # before load_state_dict, which works but is fragile).
    model.eval()

    # The dummy input fixes the exported graph's input shape:
    # batch, channel, height, width.
    dummy_input = torch.randn([1, 3, 32, 32])

    torch.onnx.export(
        model,
        dummy_input,
        # Replace the extension only — the original replace('mdl', 'onnx')
        # would rewrite ANY 'mdl' substring in the path.
        model_path.replace('.mdl', '.onnx'),
        verbose=True,
        input_names=['input'],
        output_names=['output'],
        opset_version=11,
    )
    print('Done!')

转换成功后,再转mnn,通过MNN转换工具:

./MNNConvert -f ONNX --modelFile XXX.onnx --MNNModel XXX.mnn --bizCode biz

测试pytorch的结果:

import argparse
import os
from glob import glob
import cv2
import torch
import torch.backends.cudnn as cudnn
import yaml
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from tqdm import tqdm
from PIL import Image
from mobilenetv2 import MobileNetV2
import numpy as np


def parse_args():
    """Parse command-line arguments (only the input image path)."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--image_path', default=None,
                        help='the path of image')  # fixed typo "imgae"
    return parser.parse_args()


def main():
    """Run the trained MobileNetV2 on one image and print class id / score."""
    args = parse_args()
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    start = cv2.getTickCount()

    # Create the model and load the trained weights; map_location='cpu' lets
    # the checkpoint load even on machines without a GPU.
    model = MobileNetV2(2).to(device)
    model.load_state_dict(torch.load('models/best-mobilenetv2.mdl',
                                     map_location=torch.device('cpu')))
    model.eval()

    # Preprocess: load as RGB -> tensor -> ImageNet mean/std normalization.
    # Must match the preprocessing used at training time.
    tf = transforms.Compose([
        lambda x: Image.open(x).convert('RGB'),  # string path -> image data
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])
    img = tf(args.image_path).unsqueeze(0)  # add the batch dimension
    x = img.to(device)

    # Inference only — no_grad avoids building the autograd graph.
    with torch.no_grad():
        outputs = model(x)

    # Report the most probable class and its softmax percentage.
    _, indices = torch.max(outputs, 1)
    percentage = torch.nn.functional.softmax(outputs, dim=1)[0] * 100
    perc = percentage[int(indices)].item()
    print('predicted:', perc)
    print('id:', int(indices))

    end = cv2.getTickCount()
    during = (end - start) / cv2.getTickFrequency()
    print("avg_time:", during)


if __name__ == '__main__':
    main()

测试ONNX的结果,与pytorch结果一致:

import argparse
import os
from glob import glob
import onnxruntime
import onnx
import cv2
import torch
import torch.backends.cudnn as cudnn
import yaml
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from tqdm import tqdm
from PIL import Image
from mobilenetv2 import MobileNetV2
import numpy as np


def parse_args():
    """Parse command-line arguments (only the input image path)."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--image_path', default=None,
                        help='the path of image')
    return parser.parse_args()


def to_numpy(tensor):
    """Convert a torch tensor to a numpy array, detaching it first when it
    participates in autograd.

    BUG FIX: the original used ``tensor.detach().cpu.numpy()`` — ``cpu`` was
    missing its call parentheses, so any tensor with requires_grad=True
    raised AttributeError instead of converting.
    """
    if tensor.requires_grad:
        return tensor.detach().cpu().numpy()
    return tensor.cpu().numpy()


def main():
    """Run the exported ONNX model on one image and print the raw outputs."""
    args = parse_args()
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    start = cv2.getTickCount()

    # Load the exported ONNX graph into an onnxruntime inference session.
    model = 'models/best-mobilenetv2.onnx'
    onnx_session = onnxruntime.InferenceSession(model)

    # Preprocess exactly as at training time: RGB -> tensor -> ImageNet norm.
    tf = transforms.Compose([
        lambda x: Image.open(x).convert('RGB'),  # string path -> image data
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])
    img = tf(args.image_path).unsqueeze(0)  # add the batch dimension

    # Feed the numpy input under the graph's declared input name;
    # run(None, ...) fetches every output of the graph.
    inputs = {onnx_session.get_inputs()[0].name: to_numpy(img)}
    outputs = onnx_session.run(None, inputs)
    print(outputs)

    end = cv2.getTickCount()
    during = (end - start) / cv2.getTickFrequency()
    print("avg_time:", during)


if __name__ == '__main__':
    main()

测试mnn的结果,与前面的结果一致,但是速度快了近20倍:

import argparse
import os
from glob import glob
import MNN
import cv2
import torch
import torch.backends.cudnn as cudnn
import yaml
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from tqdm import tqdm
from PIL import Image
from mobilenetv2 import MobileNetV2
import numpy as np


def parse_args():
    """Parse command-line arguments (only the input image path)."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--image_path', default=None,
                        help='the path of image')
    return parser.parse_args()


def to_numpy(tensor):
    """Convert a torch tensor to a numpy array, detaching it first when it
    participates in autograd.

    BUG FIX: the original used ``tensor.detach().cpu.numpy()`` — ``cpu`` was
    missing its call parentheses, so any tensor with requires_grad=True
    raised AttributeError instead of converting.
    """
    if tensor.requires_grad:
        return tensor.detach().cpu().numpy()
    return tensor.cpu().numpy()


def main():
    """Run the converted MNN model on one image and print the raw outputs."""
    args = parse_args()
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    start = cv2.getTickCount()

    # Load the MNN model and create an inference session.
    model = 'models/best-mobilenetv2.mnn'
    interpreter = MNN.Interpreter(model)
    mnn_session = interpreter.createSession()
    input_tensor = interpreter.getSessionInput(mnn_session)

    # Preprocess exactly as at training time: RGB -> tensor -> ImageNet norm.
    tf = transforms.Compose([
        lambda x: Image.open(x).convert('RGB'),  # string path -> image data
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])
    img = tf(args.image_path).unsqueeze(0)  # add the batch dimension

    # Wrap the numpy data in an MNN host tensor with Caffe (NCHW) layout.
    # NOTE(review): img[0] is the 3x32x32 image while the tensor is declared
    # 1x3x32x32 — MNN copies the raw buffer, but confirm this matches the
    # model's expected layout.
    tmp_input = MNN.Tensor((1, 3, 32, 32), MNN.Halide_Type_Float,
                           to_numpy(img[0]), MNN.Tensor_DimensionType_Caffe)
    print(tmp_input.getShape())
    print(input_tensor.copyFrom(tmp_input))
    input_tensor.printTensorData()

    interpreter.runSession(mnn_session)

    # Fetch the output tensor exported under the name 'output'.
    output_tensor = interpreter.getSessionOutput(mnn_session, 'output')
    output_tensor.printTensorData()
    output_data = np.array(output_tensor.getData())
    print('mnn result is:', output_data)
    print("output belong to class: {}".format(np.argmax(output_tensor.getData())))

    end = cv2.getTickCount()
    during = (end - start) / cv2.getTickFrequency()
    print("avg_time:", during)


if __name__ == '__main__':
    main()

用c++进行mnn重写测试,结果一致,这样就可以编库了:

// mnn_test.cpp : console application entry point.
#include "stdafx.h"
#include <iostream>
#include <opencv2/opencv.hpp>
#include <MNN/Interpreter.hpp>
#include <MNN/MNNDefine.h>
#include <MNN/Tensor.hpp>
#include <MNN/ImageProcess.hpp>
#include <memory>
#include <vector>
#include <cmath>
#include <cstdio>
#include <cstdlib>

#define IMAGE_VERIFY_SIZE 32
#define CLASSES_SIZE 2
#define INPUT_NAME "input"
#define OUTPUT_NAME "output"

// Swap the B and R channels: OpenCV loads images as BGR, the model was
// trained on RGB input.
cv::Mat BGRToRGB(cv::Mat img)
{
    cv::Mat image(img.rows, img.cols, CV_8UC3);
    for (int i = 0; i < img.rows; ++i) {
        cv::Vec3b *p1 = img.ptr<cv::Vec3b>(i);
        cv::Vec3b *p2 = image.ptr<cv::Vec3b>(i);
        for (int j = 0; j < img.cols; ++j) {
            p2[j][2] = p1[j][0];
            p2[j][1] = p1[j][1];
            p2[j][0] = p1[j][2];
        }
    }
    return image;
}

int main(int argc, char* argv[]) {
    // BUG FIX: the program reads argv[1] AND argv[2], so it needs at least
    // 3 arguments. The original checked (argc < 2), which let argv[2] be
    // read out of bounds when only the model path was supplied.
    if (argc < 3) {
        printf("Usage:\n\t%s mnn_model_path image_path\n", argv[0]);
        return -1;
    }

    // Create the interpreter and a 4-thread CPU session.
    const char *mnn_model_path = argv[1];
    const char *image_path = argv[2];
    auto mnnNet = std::shared_ptr<MNN::Interpreter>(
        MNN::Interpreter::createFromFile(mnn_model_path));
    MNN::ScheduleConfig netConfig;
    netConfig.type = MNN_FORWARD_CPU;
    netConfig.numThread = 4;
    auto session = mnnNet->createSession(netConfig);

    auto input = mnnNet->getSessionInput(session, INPUT_NAME);
    if (input->elementSize() <= 4) {
        // Placeholder/dynamic input shape in the model: pin it to
        // 1x3x32x32 before running.
        mnnNet->resizeTensor(input, { 1, 3, IMAGE_VERIFY_SIZE, IMAGE_VERIFY_SIZE });
        mnnNet->resizeSession(session);
    }
    std::cout << "input shape: " << input->shape()[0] << " " << input->shape()[1]
              << " " << input->shape()[2] << " " << input->shape()[3] << std::endl;

    // Preprocess: BGR->RGB, resize to 32x32, per-channel ImageNet
    // normalization, written into a CAFFE(NCHW)-layout host tensor.
    MNN::Tensor givenTensor(input, MNN::Tensor::CAFFE);
    auto inputData = givenTensor.host<float>();
    cv::Mat bgr_image = cv::imread(image_path);
    bgr_image = BGRToRGB(bgr_image);
    cv::Mat norm_image;
    cv::resize(bgr_image, norm_image, cv::Size(IMAGE_VERIFY_SIZE, IMAGE_VERIFY_SIZE));
    for (int k = 0; k < 3; k++) {
        for (int i = 0; i < norm_image.rows; i++) {
            for (int j = 0; j < norm_image.cols; j++) {
                const auto src = norm_image.at<cv::Vec3b>(i, j)[k];
                auto dst = 0.0;
                // Same mean/std as the training pipeline (ImageNet stats).
                if (k == 0) dst = (float(src) / 255.0f - 0.485) / 0.229;
                if (k == 1) dst = (float(src) / 255.0f - 0.456) / 0.224;
                if (k == 2) dst = (float(src) / 255.0f - 0.406) / 0.225;
                inputData[k * IMAGE_VERIFY_SIZE * IMAGE_VERIFY_SIZE
                          + i * IMAGE_VERIFY_SIZE + j] = dst;
            }
        }
    }
    input->copyFromHostTensor(&givenTensor);

    // Time just the forward pass. The legacy cvGetTickCount/cvGetTickFrequency
    // C API was removed in OpenCV 4; cv::getTickFrequency() returns ticks per
    // SECOND, hence the * 1000.0 to report milliseconds.
    double st = (double)cv::getTickCount();
    mnnNet->runSession(session);
    double et = ((double)cv::getTickCount() - st) / cv::getTickFrequency() * 1000.0;
    std::cout << " speed: " << et << " ms" << std::endl;

    // Copy the output back to a host tensor.
    auto output = mnnNet->getSessionOutput(session, OUTPUT_NAME);
    auto output_host = std::make_shared<MNN::Tensor>(output, MNN::Tensor::CAFFE);
    output->copyToHostTensor(output_host.get());
    auto values = output_host->host<float>();

    // Post-process: arg-max over the classes plus a softmax probability.
    std::vector<float> output_values;
    auto exp_sum = 0.0;
    auto max_index = 0;
    for (int i = 0; i < CLASSES_SIZE; i++) {
        if (values[i] > values[max_index]) max_index = i;
        output_values.push_back(values[i]);
        exp_sum += std::exp(values[i]);
    }
    std::cout << "output: " << output_values[0] << "," << output_values[1] << std::endl;
    std::cout << "id: " << max_index << std::endl;
    std::cout << "prob: " << std::exp(output_values[max_index]) / exp_sum << std::endl;
    system("pause");
    return 0;
}

 

这篇关于pytorch转onnx转mnn并验证的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/506796

相关文章

PyTorch中的词嵌入层(nn.Embedding)详解与实战应用示例

《PyTorch中的词嵌入层(nn.Embedding)详解与实战应用示例》词嵌入解决NLP维度灾难,捕捉语义关系,PyTorch的nn.Embedding模块提供灵活实现,支持参数配置、预训练及变长... 目录一、词嵌入(Word Embedding)简介为什么需要词嵌入?二、PyTorch中的nn.Em

MySQL 主从复制部署及验证(示例详解)

《MySQL主从复制部署及验证(示例详解)》本文介绍MySQL主从复制部署步骤及学校管理数据库创建脚本,包含表结构设计、示例数据插入和查询语句,用于验证主从同步功能,感兴趣的朋友一起看看吧... 目录mysql 主从复制部署指南部署步骤1.环境准备2. 主服务器配置3. 创建复制用户4. 获取主服务器状态5

Java通过驱动包(jar包)连接MySQL数据库的步骤总结及验证方式

《Java通过驱动包(jar包)连接MySQL数据库的步骤总结及验证方式》本文详细介绍如何使用Java通过JDBC连接MySQL数据库,包括下载驱动、配置Eclipse环境、检测数据库连接等关键步骤,... 目录一、下载驱动包二、放jar包三、检测数据库连接JavaJava 如何使用 JDBC 连接 mys

Spring Security中用户名和密码的验证完整流程

《SpringSecurity中用户名和密码的验证完整流程》本文给大家介绍SpringSecurity中用户名和密码的验证完整流程,本文结合实例代码给大家介绍的非常详细,对大家的学习或工作具有一定... 首先创建了一个UsernamePasswordAuthenticationToken对象,这是S

Pytorch介绍与安装过程

《Pytorch介绍与安装过程》PyTorch因其直观的设计、卓越的灵活性以及强大的动态计算图功能,迅速在学术界和工业界获得了广泛认可,成为当前深度学习研究和开发的主流工具之一,本文给大家介绍Pyto... 目录1、Pytorch介绍1.1、核心理念1.2、核心组件与功能1.3、适用场景与优势总结1.4、优

conda安装GPU版pytorch默认却是cpu版本

《conda安装GPU版pytorch默认却是cpu版本》本文主要介绍了遇到Conda安装PyTorchGPU版本却默认安装CPU的问题,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的... 目录一、问题描述二、网上解决方案罗列【此节为反面方案罗列!!!】三、发现的根本原因[独家]3.1 p

PyTorch中cdist和sum函数使用示例详解

《PyTorch中cdist和sum函数使用示例详解》torch.cdist是PyTorch中用于计算**两个张量之间的成对距离(pairwisedistance)**的函数,常用于点云处理、图神经网... 目录基本语法输出示例1. 简单的 2D 欧几里得距离2. 批量形式(3D Tensor)3. 使用不

PyTorch高级特性与性能优化方式

《PyTorch高级特性与性能优化方式》:本文主要介绍PyTorch高级特性与性能优化方式,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录一、自动化机制1.自动微分机制2.动态计算图二、性能优化1.内存管理2.GPU加速3.多GPU训练三、分布式训练1.分布式数据

判断PyTorch是GPU版还是CPU版的方法小结

《判断PyTorch是GPU版还是CPU版的方法小结》PyTorch作为当前最流行的深度学习框架之一,支持在CPU和GPU(NVIDIACUDA)上运行,所以对于深度学习开发者来说,正确识别PyTor... 目录前言为什么需要区分GPU和CPU版本?性能差异硬件要求如何检查PyTorch版本?方法1:使用命

Linux内核参数配置与验证详细指南

《Linux内核参数配置与验证详细指南》在Linux系统运维和性能优化中,内核参数(sysctl)的配置至关重要,本文主要来聊聊如何配置与验证这些Linux内核参数,希望对大家有一定的帮助... 目录1. 引言2. 内核参数的作用3. 如何设置内核参数3.1 临时设置(重启失效)3.2 永久设置(重启仍生效